Commit 8f3d9855 authored by Matthew Dempsky's avatar Matthew Dempsky

cmd/compile: major refactoring of switch walking

There are a lot of complexities to handling switches efficiently:

1. Order matters for expression switches with non-constant cases and
for type expressions with interface types. We have to respect
side-effects, and we also can't allow later cases to accidentally take
precedence over earlier cases.

2. For runs of integers, floats, and string constants in expression
switches or runs of concrete types in type switches, we want to emit
efficient binary searches.

3. For runs of consecutive integers in expression switches, we want to
collapse them into range comparisons.

4. For binary searches of strings, we want to compare by length first,
because that's more efficient and we don't need to respect any
particular ordering.

5. For "switch true { ... }" and "switch false { ... }", we want to
optimize "case x:" as simply "if x" or "if !x", respectively, unless x
is interface-typed.

The current swt.go code reflects how these constraints have been
incrementally added over time, with each of them being handled ad
hocly in different parts of the code. Also, the existing code tries
very hard to reuse logic between expression and type switches, even
though the similarities are very superficial.

This CL rewrites switch handling to better abstract away the logic
involved in constructing the binary searches. In particular, it's
intended to make further optimizations to switch dispatch much easier.

It also eliminates the need for both OXCASE and OCASE ops, and a
subsequent CL can collapse the two.

Passes toolstash-check.

Change-Id: Ifcd1e56f81f858117a412971d82e98abe7c4481f
Reviewed-on: https://go-review.googlesource.com/c/go/+/194660
Run-TryBot: Matthew Dempsky <mdempsky@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarKeith Randall <khr@golang.org>
parent 115e4c9c
...@@ -10,50 +10,6 @@ import ( ...@@ -10,50 +10,6 @@ import (
"sort" "sort"
) )
const (
// expression switch
switchKindExpr = iota // switch a {...} or switch 5 {...}
switchKindTrue // switch true {...} or switch {...}
switchKindFalse // switch false {...}
)
const (
binarySearchMin = 4 // minimum number of cases for binary search
integerRangeMin = 2 // minimum size of integer ranges
)
// An exprSwitch walks an expression switch.
type exprSwitch struct {
exprname *Node // node for the expression being switched on
kind int // kind of switch statement (switchKind*)
}
// A typeSwitch walks a type switch.
type typeSwitch struct {
hashname *Node // node for the hash of the type of the variable being switched on
facename *Node // node for the concrete type of the variable being switched on
okname *Node // boolean node used for comma-ok type assertions
}
// A caseClause is a single case clause in a switch statement.
type caseClause struct {
node *Node // points at case statement
ordinal int // position in switch
hash uint32 // hash of a type switch
// isconst indicates whether this case clause is a constant,
// for the purposes of the switch code generation.
// For expression switches, that's generally literals (case 5:, not case x:).
// For type switches, that's concrete types (case time.Time:), not interfaces (case io.Reader:).
isconst bool
}
// caseClauses are all the case clauses in a switch statement.
type caseClauses struct {
list []caseClause // general cases
defjmp *Node // OGOTO for default case or OBREAK if no default case present
niljmp *Node // OGOTO for nil type case in a type switch
}
// typecheckswitch typechecks a switch statement. // typecheckswitch typechecks a switch statement.
func typecheckswitch(n *Node) { func typecheckswitch(n *Node) {
typecheckslice(n.Ninit.Slice(), ctxStmt) typecheckslice(n.Ninit.Slice(), ctxStmt)
...@@ -71,7 +27,6 @@ func typecheckTypeSwitch(n *Node) { ...@@ -71,7 +27,6 @@ func typecheckTypeSwitch(n *Node) {
yyerrorl(n.Pos, "cannot type switch on non-interface value %L", n.Left.Right) yyerrorl(n.Pos, "cannot type switch on non-interface value %L", n.Left.Right)
t = nil t = nil
} }
n.Type = t // TODO(mdempsky): Remove; statements aren't typed.
// We don't actually declare the type switch's guarded // We don't actually declare the type switch's guarded
// declaration itself. So if there are no cases, we won't // declaration itself. So if there are no cases, we won't
...@@ -212,7 +167,6 @@ func typecheckExprSwitch(n *Node) { ...@@ -212,7 +167,6 @@ func typecheckExprSwitch(n *Node) {
t = nil t = nil
} }
} }
n.Type = t // TODO(mdempsky): Remove; statements aren't typed.
var defCase *Node var defCase *Node
var cs constSet var cs constSet
...@@ -265,422 +219,267 @@ func typecheckExprSwitch(n *Node) { ...@@ -265,422 +219,267 @@ func typecheckExprSwitch(n *Node) {
// walkswitch walks a switch statement. // walkswitch walks a switch statement.
func walkswitch(sw *Node) { func walkswitch(sw *Node) {
// convert switch {...} to switch true {...} // Guard against double walk, see #25776.
if sw.Left == nil { if sw.List.Len() == 0 && sw.Nbody.Len() > 0 {
sw.Left = nodbool(true) return // Was fatal, but eliminating every possible source of double-walking is hard
sw.Left = typecheck(sw.Left, ctxExpr)
sw.Left = defaultlit(sw.Left, nil)
} }
if sw.Left.Op == OTYPESW { if sw.Left != nil && sw.Left.Op == OTYPESW {
var s typeSwitch walkTypeSwitch(sw)
s.walk(sw)
} else { } else {
var s exprSwitch walkExprSwitch(sw)
s.walk(sw)
} }
} }
// walk generates an AST implementing sw. // walkExprSwitch generates an AST implementing sw. sw is an
// sw is an expression switch. // expression switch.
// The AST is generally of the form of a linear func walkExprSwitch(sw *Node) {
// search using if..goto, although binary search lno := setlineno(sw)
// is used with long runs of constants.
func (s *exprSwitch) walk(sw *Node) {
// Guard against double walk, see #25776.
if sw.List.Len() == 0 && sw.Nbody.Len() > 0 {
return // Was fatal, but eliminating every possible source of double-walking is hard
}
casebody(sw, nil)
cond := sw.Left cond := sw.Left
sw.Left = nil sw.Left = nil
s.kind = switchKindExpr // convert switch {...} to switch true {...}
if Isconst(cond, CTBOOL) { if cond == nil {
s.kind = switchKindTrue cond = nodbool(true)
if !cond.Val().U.(bool) { cond = typecheck(cond, ctxExpr)
s.kind = switchKindFalse cond = defaultlit(cond, nil)
}
} }
// Given "switch string(byteslice)", // Given "switch string(byteslice)",
// with all cases being constants (or the default case), // with all cases being side-effect free,
// use a zero-cost alias of the byte slice. // use a zero-cost alias of the byte slice.
// In theory, we could be more aggressive,
// allowing any side-effect-free expressions in cases,
// but it's a bit tricky because some of that information
// is unavailable due to the introduction of temporaries during order.
// Restricting to constants is simple and probably powerful enough.
// Do this before calling walkexpr on cond, // Do this before calling walkexpr on cond,
// because walkexpr will lower the string // because walkexpr will lower the string
// conversion into a runtime call. // conversion into a runtime call.
// See issue 24937 for more discussion. // See issue 24937 for more discussion.
if cond.Op == OBYTES2STR { if cond.Op == OBYTES2STR && allCaseExprsAreSideEffectFree(sw) {
ok := true cond.Op = OBYTES2STRTMP
for _, cas := range sw.List.Slice() {
if cas.Op != OCASE {
Fatalf("switch string(byteslice) bad op: %v", cas.Op)
}
if cas.Left != nil && !Isconst(cas.Left, CTSTR) {
ok = false
break
}
}
if ok {
cond.Op = OBYTES2STRTMP
}
} }
cond = walkexpr(cond, &sw.Ninit) cond = walkexpr(cond, &sw.Ninit)
t := sw.Type if cond.Op != OLITERAL {
if t == nil { cond = copyexpr(cond, cond.Type, &sw.Nbody)
return
} }
// convert the switch into OIF statements lineno = lno
var cas []*Node
if s.kind == switchKindTrue || s.kind == switchKindFalse { s := exprSwitch{
s.exprname = nodbool(s.kind == switchKindTrue) exprname: cond,
} else if consttype(cond) > 0 {
// leave constants to enable dead code elimination (issue 9608)
s.exprname = cond
} else {
s.exprname = temp(cond.Type)
cas = []*Node{nod(OAS, s.exprname, cond)} // This gets walk()ed again in walkstmtlist just before end of this function. See #29562.
typecheckslice(cas, ctxStmt)
} }
// Enumerate the cases and prepare the default case. br := nod(OBREAK, nil, nil)
clauses := s.genCaseClauses(sw.List.Slice()) var defaultGoto *Node
sw.List.Set(nil) var body Nodes
cc := clauses.list for _, ncase := range sw.List.Slice() {
label := autolabel(".s")
// handle the cases in order jmp := npos(ncase.Pos, nodSym(OGOTO, nil, label))
for len(cc) > 0 {
run := 1 // Process case dispatch.
if okforcmp[t.Etype] && cc[0].isconst { if ncase.List.Len() == 0 {
// do binary search on runs of constants if defaultGoto != nil {
for ; run < len(cc) && cc[run].isconst; run++ { Fatalf("duplicate default case not detected during typechecking")
} }
// sort and compile constants defaultGoto = jmp
sort.Sort(caseClauseByConstVal(cc[:run]))
} }
a := s.walkCases(cc[:run]) for _, n1 := range ncase.List.Slice() {
cas = append(cas, a) s.Add(ncase.Pos, n1, jmp)
cc = cc[run:] }
// Process body.
body.Append(npos(ncase.Pos, nodSym(OLABEL, nil, label)))
body.Append(ncase.Nbody.Slice()...)
if !hasFall(ncase.Nbody.Slice()) {
body.Append(br)
}
} }
sw.List.Set(nil)
// handle default case if defaultGoto == nil {
if nerrors == 0 { defaultGoto = br
cas = append(cas, clauses.defjmp)
sw.Nbody.Prepend(cas...)
walkstmtlist(sw.Nbody.Slice())
} }
}
// walkCases generates an AST implementing the cases in cc. s.Emit(&sw.Nbody)
func (s *exprSwitch) walkCases(cc []caseClause) *Node { sw.Nbody.Append(defaultGoto)
if len(cc) < binarySearchMin { sw.Nbody.AppendNodes(&body)
// linear search walkstmtlist(sw.Nbody.Slice())
var cas []*Node }
for _, c := range cc {
n := c.node
lno := setlineno(n)
a := nod(OIF, nil, nil) // An exprSwitch walks an expression switch.
if rng := n.List.Slice(); rng != nil { type exprSwitch struct {
// Integer range. exprname *Node // value being switched on
// exprname is a temp or a constant,
// so it is safe to evaluate twice.
// In most cases, this conjunction will be
// rewritten by walkinrange into a single comparison.
low := nod(OGE, s.exprname, rng[0])
high := nod(OLE, s.exprname, rng[1])
a.Left = nod(OANDAND, low, high)
} else if (s.kind != switchKindTrue && s.kind != switchKindFalse) || assignop(n.Left.Type, s.exprname.Type, nil) == OCONVIFACE || assignop(s.exprname.Type, n.Left.Type, nil) == OCONVIFACE {
a.Left = nod(OEQ, s.exprname, n.Left) // if name == val
} else if s.kind == switchKindTrue {
a.Left = n.Left // if val
} else {
// s.kind == switchKindFalse
a.Left = nod(ONOT, n.Left, nil) // if !val
}
a.Left = typecheck(a.Left, ctxExpr)
a.Left = defaultlit(a.Left, nil)
a.Nbody.Set1(n.Right) // goto l
cas = append(cas, a) done Nodes
lineno = lno clauses []exprClause
} }
return liststmt(cas)
}
// find the middle and recur type exprClause struct {
half := len(cc) / 2 pos src.XPos
a := nod(OIF, nil, nil) lo, hi *Node
n := cc[half-1].node jmp *Node
var mid *Node
if rng := n.List.Slice(); rng != nil {
mid = rng[1] // high end of range
} else {
mid = n.Left
}
le := nod(OLE, s.exprname, mid)
if Isconst(mid, CTSTR) {
// Search by length and then by value; see caseClauseByConstVal.
lenlt := nod(OLT, nod(OLEN, s.exprname, nil), nod(OLEN, mid, nil))
leneq := nod(OEQ, nod(OLEN, s.exprname, nil), nod(OLEN, mid, nil))
a.Left = nod(OOROR, lenlt, nod(OANDAND, leneq, le))
} else {
a.Left = le
}
a.Left = typecheck(a.Left, ctxExpr)
a.Left = defaultlit(a.Left, nil)
a.Nbody.Set1(s.walkCases(cc[:half]))
a.Rlist.Set1(s.walkCases(cc[half:]))
return a
} }
// casebody builds separate lists of statements and cases. func (s *exprSwitch) Add(pos src.XPos, expr, jmp *Node) {
// It makes labels between cases and statements c := exprClause{pos: pos, lo: expr, hi: expr, jmp: jmp}
// and deals with fallthrough, break, and unreachable statements. if okforcmp[s.exprname.Type.Etype] && expr.Op == OLITERAL {
func casebody(sw *Node, typeswvar *Node) { s.clauses = append(s.clauses, c)
if sw.List.Len() == 0 {
return return
} }
lno := setlineno(sw) s.flush()
s.clauses = append(s.clauses, c)
s.flush()
}
var cas []*Node // cases func (s *exprSwitch) Emit(out *Nodes) {
var stat []*Node // statements s.flush()
var def *Node // defaults out.AppendNodes(&s.done)
br := nod(OBREAK, nil, nil) }
for _, n := range sw.List.Slice() { func (s *exprSwitch) flush() {
setlineno(n) cc := s.clauses
if n.Op != OXCASE { s.clauses = nil
Fatalf("casebody %v", n.Op) if len(cc) == 0 {
} return
n.Op = OCASE }
needvar := n.List.Len() != 1 || n.List.First().Op == OLITERAL
// Caution: If len(cc) == 1, then cc[0] might not an OLITERAL.
lbl := autolabel(".s") // The code below is structured to implicitly handle this case
jmp := nodSym(OGOTO, nil, lbl) // (e.g., sort.Slice doesn't need to invoke the less function
switch n.List.Len() { // when there's only a single slice element).
case 0:
// default // Sort strings by length and then by value.
if def != nil { // It is much cheaper to compare lengths than values,
yyerrorl(n.Pos, "more than one default case") // and all we need here is consistency.
} // We respect this sorting below.
// reuse original default case sort.Slice(cc, func(i, j int) bool {
n.Right = jmp vi := cc[i].lo.Val()
def = n vj := cc[j].lo.Val()
case 1:
// one case -- reuse OCASE node if s.exprname.Type.IsString() {
n.Left = n.List.First() si := vi.U.(string)
n.Right = jmp sj := vj.U.(string)
n.List.Set(nil) if len(si) != len(sj) {
cas = append(cas, n) return len(si) < len(sj)
default:
// Expand multi-valued cases and detect ranges of integer cases.
if typeswvar != nil || sw.Left.Type.IsInterface() || !n.List.First().Type.IsInteger() || n.List.Len() < integerRangeMin {
// Can't use integer ranges. Expand each case into a separate node.
for _, n1 := range n.List.Slice() {
cas = append(cas, nod(OCASE, n1, jmp))
}
break
}
// Find integer ranges within runs of constants.
s := n.List.Slice()
j := 0
for j < len(s) {
// Find a run of constants.
var run int
for run = j; run < len(s) && Isconst(s[run], CTINT); run++ {
}
if run-j >= integerRangeMin {
// Search for integer ranges in s[j:run].
// Typechecking is done, so all values are already in an appropriate range.
search := s[j:run]
sort.Sort(constIntNodesByVal(search))
for beg, end := 0, 1; end <= len(search); end++ {
if end < len(search) && search[end].Int64() == search[end-1].Int64()+1 {
continue
}
if end-beg >= integerRangeMin {
// Record range in List.
c := nod(OCASE, nil, jmp)
c.List.Set2(search[beg], search[end-1])
cas = append(cas, c)
} else {
// Not large enough for range; record separately.
for _, n := range search[beg:end] {
cas = append(cas, nod(OCASE, n, jmp))
}
}
beg = end
}
j = run
}
// Advance to next constant, adding individual non-constant
// or as-yet-unhandled constant cases as we go.
for ; j < len(s) && (j < run || !Isconst(s[j], CTINT)); j++ {
cas = append(cas, nod(OCASE, s[j], jmp))
}
} }
return si < sj
} }
stat = append(stat, nodSym(OLABEL, nil, lbl)) return compareOp(vi, OLT, vj)
if typeswvar != nil && needvar && n.Rlist.Len() != 0 { })
l := []*Node{
nod(ODCL, n.Rlist.First(), nil), // Merge consecutive integer cases.
nod(OAS, n.Rlist.First(), typeswvar), if s.exprname.Type.IsInteger() {
merged := cc[:1]
for _, c := range cc[1:] {
last := &merged[len(merged)-1]
if last.jmp == c.jmp && last.hi.Int64()+1 == c.lo.Int64() {
last.hi = c.lo
} else {
merged = append(merged, c)
} }
typecheckslice(l, ctxStmt)
stat = append(stat, l...)
}
stat = append(stat, n.Nbody.Slice()...)
// Search backwards for the index of the fallthrough
// statement. Do not assume it'll be in the last
// position, since in some cases (e.g. when the statement
// list contains autotmp_ variables), one or more OVARKILL
// nodes will be at the end of the list.
fallIndex := len(stat) - 1
for stat[fallIndex].Op == OVARKILL {
fallIndex--
}
last := stat[fallIndex]
if last.Op != OFALL {
stat = append(stat, br)
} }
cc = merged
} }
stat = append(stat, br) binarySearch(len(cc), &s.done,
if def != nil { func(i int) *Node {
cas = append(cas, def) mid := cc[i-1].hi
}
sw.List.Set(cas) le := nod(OLE, s.exprname, mid)
sw.Nbody.Set(stat) if s.exprname.Type.IsString() {
lineno = lno // Compare strings by length and then
// by value; see sort.Slice above.
lenlt := nod(OLT, nod(OLEN, s.exprname, nil), nod(OLEN, mid, nil))
leneq := nod(OEQ, nod(OLEN, s.exprname, nil), nod(OLEN, mid, nil))
le = nod(OOROR, lenlt, nod(OANDAND, leneq, le))
}
return le
},
func(i int, out *Nodes) {
c := &cc[i]
nif := nodl(c.pos, OIF, c.test(s.exprname), nil)
nif.Left = typecheck(nif.Left, ctxExpr)
nif.Left = defaultlit(nif.Left, nil)
nif.Nbody.Set1(c.jmp)
out.Append(nif)
},
)
} }
// genCaseClauses generates the caseClauses value for clauses. func (c *exprClause) test(exprname *Node) *Node {
func (s *exprSwitch) genCaseClauses(clauses []*Node) caseClauses { // Integer range.
var cc caseClauses if c.hi != c.lo {
for _, n := range clauses { low := nodl(c.pos, OGE, exprname, c.lo)
if n.Left == nil && n.List.Len() == 0 { high := nodl(c.pos, OLE, exprname, c.hi)
// default case return nodl(c.pos, OANDAND, low, high)
if cc.defjmp != nil {
Fatalf("duplicate default case not detected during typechecking")
}
cc.defjmp = n.Right
continue
}
c := caseClause{node: n, ordinal: len(cc.list)}
if n.List.Len() > 0 {
c.isconst = true
}
switch consttype(n.Left) {
case CTFLT, CTINT, CTRUNE, CTSTR:
c.isconst = true
}
cc.list = append(cc.list, c)
} }
if cc.defjmp == nil { // Optimize "switch true { ...}" and "switch false { ... }".
cc.defjmp = nod(OBREAK, nil, nil) if Isconst(exprname, CTBOOL) && !c.lo.Type.IsInterface() {
if exprname.Val().U.(bool) {
return c.lo
} else {
return nodl(c.pos, ONOT, c.lo, nil)
}
} }
return cc
return nodl(c.pos, OEQ, exprname, c.lo)
} }
// genCaseClauses generates the caseClauses value for clauses. func allCaseExprsAreSideEffectFree(sw *Node) bool {
func (s *typeSwitch) genCaseClauses(clauses []*Node) caseClauses { // In theory, we could be more aggressive, allowing any
var cc caseClauses // side-effect-free expressions in cases, but it's a bit
for _, n := range clauses { // tricky because some of that information is unavailable due
switch { // to the introduction of temporaries during order.
case n.Left == nil: // Restricting to constants is simple and probably powerful
// default case // enough.
if cc.defjmp != nil {
Fatalf("duplicate default case not detected during typechecking") for _, ncase := range sw.List.Slice() {
} if ncase.Op != OXCASE {
cc.defjmp = n.Right Fatalf("switch string(byteslice) bad op: %v", ncase.Op)
continue
case n.Left.Op == OLITERAL:
// nil case in type switch
if cc.niljmp != nil {
Fatalf("duplicate nil case not detected during typechecking")
}
cc.niljmp = n.Right
continue
} }
for _, v := range ncase.List.Slice() {
// general case if v.Op != OLITERAL {
c := caseClause{ return false
node: n, }
ordinal: len(cc.list),
isconst: !n.Left.Type.IsInterface(),
hash: typehash(n.Left.Type),
} }
cc.list = append(cc.list, c)
}
if cc.defjmp == nil {
cc.defjmp = nod(OBREAK, nil, nil)
} }
return true
return cc
} }
// walk generates an AST that implements sw, // hasFall reports whether stmts ends with a "fallthrough" statement.
// where sw is a type switch. func hasFall(stmts []*Node) bool {
// The AST is generally of the form of a linear // Search backwards for the index of the fallthrough
// search using if..goto, although binary search // statement. Do not assume it'll be in the last
// is used with long runs of concrete types. // position, since in some cases (e.g. when the statement
func (s *typeSwitch) walk(sw *Node) { // list contains autotmp_ variables), one or more OVARKILL
cond := sw.Left // nodes will be at the end of the list.
sw.Left = nil
if cond == nil { i := len(stmts) - 1
sw.List.Set(nil) for i >= 0 && stmts[i].Op == OVARKILL {
return i--
} }
if cond.Right == nil { return i >= 0 && stmts[i].Op == OFALL
yyerrorl(sw.Pos, "type switch must have an assignment") }
return
}
cond.Right = walkexpr(cond.Right, &sw.Ninit)
if !cond.Right.Type.IsInterface() {
yyerrorl(sw.Pos, "type switch must be on an interface")
return
}
var cas []*Node
// predeclare temporary variables and the boolean var
s.facename = temp(cond.Right.Type)
a := nod(OAS, s.facename, cond.Right) // walkTypeSwitch generates an AST that implements sw, where sw is a
a = typecheck(a, ctxStmt) // type switch.
cas = append(cas, a) func walkTypeSwitch(sw *Node) {
var s typeSwitch
s.facename = sw.Left.Right
sw.Left = nil
s.facename = walkexpr(s.facename, &sw.Ninit)
s.facename = copyexpr(s.facename, s.facename.Type, &sw.Nbody)
s.okname = temp(types.Types[TBOOL]) s.okname = temp(types.Types[TBOOL])
s.okname = typecheck(s.okname, ctxExpr)
s.hashname = temp(types.Types[TUINT32])
s.hashname = typecheck(s.hashname, ctxExpr)
// set up labels and jumps // Get interface descriptor word.
casebody(sw, s.facename) // For empty interfaces this will be the type.
// For non-empty interfaces this will be the itab.
clauses := s.genCaseClauses(sw.List.Slice()) itab := nod(OITAB, s.facename, nil)
sw.List.Set(nil)
def := clauses.defjmp
// For empty interfaces, do: // For empty interfaces, do:
// if e._type == nil { // if e._type == nil {
...@@ -688,230 +487,235 @@ func (s *typeSwitch) walk(sw *Node) { ...@@ -688,230 +487,235 @@ func (s *typeSwitch) walk(sw *Node) {
// } // }
// h := e._type.hash // h := e._type.hash
// Use a similar strategy for non-empty interfaces. // Use a similar strategy for non-empty interfaces.
ifNil := nod(OIF, nil, nil)
// Get interface descriptor word. ifNil.Left = nod(OEQ, itab, nodnil())
// For empty interfaces this will be the type. ifNil.Left = typecheck(ifNil.Left, ctxExpr)
// For non-empty interfaces this will be the itab. ifNil.Left = defaultlit(ifNil.Left, nil)
itab := nod(OITAB, s.facename, nil) // ifNil.Nbody assigned at end.
sw.Nbody.Append(ifNil)
// Check for nil first.
i := nod(OIF, nil, nil)
i.Left = nod(OEQ, itab, nodnil())
if clauses.niljmp != nil {
// Do explicit nil case right here.
i.Nbody.Set1(clauses.niljmp)
} else {
// Jump to default case.
lbl := autolabel(".s")
i.Nbody.Set1(nodSym(OGOTO, nil, lbl))
// Wrap default case with label.
blk := nod(OBLOCK, nil, nil)
blk.List.Set2(nodSym(OLABEL, nil, lbl), def)
def = blk
}
i.Left = typecheck(i.Left, ctxExpr)
i.Left = defaultlit(i.Left, nil)
cas = append(cas, i)
// Load hash from type or itab. // Load hash from type or itab.
h := nodSym(ODOTPTR, itab, nil) dotHash := nodSym(ODOTPTR, itab, nil)
h.Type = types.Types[TUINT32] dotHash.Type = types.Types[TUINT32]
h.SetTypecheck(1) dotHash.SetTypecheck(1)
if cond.Right.Type.IsEmptyInterface() { if s.facename.Type.IsEmptyInterface() {
h.Xoffset = int64(2 * Widthptr) // offset of hash in runtime._type dotHash.Xoffset = int64(2 * Widthptr) // offset of hash in runtime._type
} else { } else {
h.Xoffset = int64(2 * Widthptr) // offset of hash in runtime.itab dotHash.Xoffset = int64(2 * Widthptr) // offset of hash in runtime.itab
} }
h.SetBounded(true) // guaranteed not to fault dotHash.SetBounded(true) // guaranteed not to fault
a = nod(OAS, s.hashname, h) s.hashname = copyexpr(dotHash, dotHash.Type, &sw.Nbody)
a = typecheck(a, ctxStmt)
cas = append(cas, a)
cc := clauses.list br := nod(OBREAK, nil, nil)
var defaultGoto, nilGoto *Node
var body Nodes
for _, ncase := range sw.List.Slice() {
var caseVar *Node
if ncase.Rlist.Len() != 0 {
caseVar = ncase.Rlist.First()
}
// insert type equality check into each case block // For single-type cases, we initialize the case
for _, c := range cc { // variable as part of the type assertion; but in
c.node.Right = s.typeone(c.node) // other cases, we initialize it in the body.
} singleType := ncase.List.Len() == 1 && ncase.List.First().Op == OTYPE
// generate list of if statements, binary search for constant sequences label := autolabel(".s")
for len(cc) > 0 {
if !cc[0].isconst {
n := cc[0].node
cas = append(cas, n.Right)
cc = cc[1:]
continue
}
// identify run of constants jmp := npos(ncase.Pos, nodSym(OGOTO, nil, label))
var run int if ncase.List.Len() == 0 { // default:
for run = 1; run < len(cc) && cc[run].isconst; run++ { if defaultGoto != nil {
Fatalf("duplicate default case not detected during typechecking")
}
defaultGoto = jmp
} }
// sort by hash for _, n1 := range ncase.List.Slice() {
sort.Sort(caseClauseByType(cc[:run])) if n1.isNil() { // case nil:
if nilGoto != nil {
Fatalf("duplicate nil case not detected during typechecking")
}
nilGoto = jmp
continue
}
// for debugging: linear search if singleType {
if false { s.Add(n1.Type, caseVar, jmp)
for i := 0; i < run; i++ { } else {
n := cc[i].node s.Add(n1.Type, nil, jmp)
cas = append(cas, n.Right)
} }
continue
} }
// combine adjacent cases with the same hash body.Append(npos(ncase.Pos, nodSym(OLABEL, nil, label)))
var batch []caseClause if caseVar != nil && !singleType {
for i, j := 0, 0; i < run; i = j { l := []*Node{
hash := []*Node{cc[i].node.Right} nodl(ncase.Pos, ODCL, caseVar, nil),
for j = i + 1; j < run && cc[i].hash == cc[j].hash; j++ { nodl(ncase.Pos, OAS, caseVar, s.facename),
hash = append(hash, cc[j].node.Right)
} }
cc[i].node.Right = liststmt(hash) typecheckslice(l, ctxStmt)
batch = append(batch, cc[i]) body.Append(l...)
} }
body.Append(ncase.Nbody.Slice()...)
body.Append(br)
}
sw.List.Set(nil)
// binary search among cases to narrow by hash if defaultGoto == nil {
cas = append(cas, s.walkCases(batch)) defaultGoto = br
cc = cc[run:]
} }
// handle default case if nilGoto != nil {
if nerrors == 0 { ifNil.Nbody.Set1(nilGoto)
cas = append(cas, def) } else {
sw.Nbody.Prepend(cas...) // TODO(mdempsky): Just use defaultGoto directly.
sw.List.Set(nil)
walkstmtlist(sw.Nbody.Slice()) // Jump to default case.
label := autolabel(".s")
ifNil.Nbody.Set1(nodSym(OGOTO, nil, label))
// Wrap default case with label.
blk := nod(OBLOCK, nil, nil)
blk.List.Set2(nodSym(OLABEL, nil, label), defaultGoto)
defaultGoto = blk
} }
s.Emit(&sw.Nbody)
sw.Nbody.Append(defaultGoto)
sw.Nbody.AppendNodes(&body)
walkstmtlist(sw.Nbody.Slice())
} }
// typeone generates an AST that jumps to the // A typeSwitch walks a type switch.
// case body if the variable is of type t. type typeSwitch struct {
func (s *typeSwitch) typeone(t *Node) *Node { // Temporary variables (i.e., ONAMEs) used by type switch dispatch logic:
var name *Node facename *Node // value being type-switched on
var init Nodes hashname *Node // type hash of the value being type-switched on
if t.Rlist.Len() == 0 { okname *Node // boolean used for comma-ok type assertions
name = nblank
nblank = typecheck(nblank, ctxExpr|ctxAssign) done Nodes
} else { clauses []typeClause
name = t.Rlist.First()
init.Append(nod(ODCL, name, nil))
a := nod(OAS, name, nil)
a = typecheck(a, ctxStmt)
init.Append(a)
}
a := nod(OAS2, nil, nil)
a.List.Set2(name, s.okname) // name, ok =
b := nod(ODOTTYPE, s.facename, nil)
b.Type = t.Left.Type // interface.(type)
a.Rlist.Set1(b)
a = typecheck(a, ctxStmt)
a = walkexpr(a, &init)
init.Append(a)
c := nod(OIF, nil, nil)
c.Left = s.okname
c.Nbody.Set1(t.Right) // if ok { goto l }
init.Append(c)
return init.asblock()
} }
// walkCases generates an AST implementing the cases in cc. type typeClause struct {
func (s *typeSwitch) walkCases(cc []caseClause) *Node { hash uint32
if len(cc) < binarySearchMin { body Nodes
var cas []*Node
for _, c := range cc {
n := c.node
if !c.isconst {
Fatalf("typeSwitch walkCases")
}
a := nod(OIF, nil, nil)
a.Left = nod(OEQ, s.hashname, nodintconst(int64(c.hash)))
a.Left = typecheck(a.Left, ctxExpr)
a.Left = defaultlit(a.Left, nil)
a.Nbody.Set1(n.Right)
cas = append(cas, a)
}
return liststmt(cas)
}
// find the middle and recur
half := len(cc) / 2
a := nod(OIF, nil, nil)
a.Left = nod(OLE, s.hashname, nodintconst(int64(cc[half-1].hash)))
a.Left = typecheck(a.Left, ctxExpr)
a.Left = defaultlit(a.Left, nil)
a.Nbody.Set1(s.walkCases(cc[:half]))
a.Rlist.Set1(s.walkCases(cc[half:]))
return a
} }
// caseClauseByConstVal sorts clauses by constant value to enable binary search. func (s *typeSwitch) Add(typ *types.Type, caseVar *Node, jmp *Node) {
type caseClauseByConstVal []caseClause var body Nodes
if caseVar != nil {
func (x caseClauseByConstVal) Len() int { return len(x) } l := []*Node{
func (x caseClauseByConstVal) Swap(i, j int) { x[i], x[j] = x[j], x[i] } nod(ODCL, caseVar, nil),
func (x caseClauseByConstVal) Less(i, j int) bool { nod(OAS, caseVar, nil),
// n1 and n2 might be individual constants or integer ranges. }
// We have checked for duplicates already, typecheckslice(l, ctxStmt)
// so ranges can be safely represented by any value in the range. body.Append(l...)
n1 := x[i].node
var v1 interface{}
if s := n1.List.Slice(); s != nil {
v1 = s[0].Val().U
} else { } else {
v1 = n1.Left.Val().U caseVar = nblank
}
// cv, ok = iface.(type)
as := nod(OAS2, nil, nil)
as.List.Set2(caseVar, s.okname) // cv, ok =
dot := nod(ODOTTYPE, s.facename, nil)
dot.Type = typ // iface.(type)
as.Rlist.Set1(dot)
as = typecheck(as, ctxStmt)
as = walkexpr(as, &body)
body.Append(as)
// if ok { goto label }
nif := nod(OIF, nil, nil)
nif.Left = s.okname
nif.Nbody.Set1(jmp)
body.Append(nif)
if !typ.IsInterface() {
s.clauses = append(s.clauses, typeClause{
hash: typehash(typ),
body: body,
})
return
} }
n2 := x[j].node s.flush()
var v2 interface{} s.done.AppendNodes(&body)
if s := n2.List.Slice(); s != nil { }
v2 = s[0].Val().U
} else {
v2 = n2.Left.Val().U
}
switch v1 := v1.(type) {
case *Mpflt:
return v1.Cmp(v2.(*Mpflt)) < 0
case *Mpint:
return v1.Cmp(v2.(*Mpint)) < 0
case string:
// Sort strings by length and then by value.
// It is much cheaper to compare lengths than values,
// and all we need here is consistency.
// We respect this sorting in exprSwitch.walkCases.
a := v1
b := v2.(string)
if len(a) != len(b) {
return len(a) < len(b)
}
return a < b
}
Fatalf("caseClauseByConstVal passed bad clauses %v < %v", x[i].node.Left, x[j].node.Left) func (s *typeSwitch) Emit(out *Nodes) {
return false s.flush()
out.AppendNodes(&s.done)
} }
type caseClauseByType []caseClause func (s *typeSwitch) flush() {
cc := s.clauses
s.clauses = nil
if len(cc) == 0 {
return
}
sort.Slice(cc, func(i, j int) bool { return cc[i].hash < cc[j].hash })
func (x caseClauseByType) Len() int { return len(x) } // Combine adjacent cases with the same hash.
func (x caseClauseByType) Swap(i, j int) { x[i], x[j] = x[j], x[i] } merged := cc[:1]
func (x caseClauseByType) Less(i, j int) bool { for _, c := range cc[1:] {
c1, c2 := x[i], x[j] last := &merged[len(merged)-1]
// sort by hash code, then ordinal (for the rare case of hash collisions) if last.hash == c.hash {
if c1.hash != c2.hash { last.body.AppendNodes(&c.body)
return c1.hash < c2.hash } else {
merged = append(merged, c)
}
} }
return c1.ordinal < c2.ordinal cc = merged
binarySearch(len(cc), &s.done,
func(i int) *Node {
return nod(OLE, s.hashname, nodintconst(int64(cc[i-1].hash)))
},
func(i int, out *Nodes) {
// TODO(mdempsky): Omit hash equality check if
// there's only one type.
c := cc[i]
a := nod(OIF, nil, nil)
a.Left = nod(OEQ, s.hashname, nodintconst(int64(c.hash)))
a.Left = typecheck(a.Left, ctxExpr)
a.Left = defaultlit(a.Left, nil)
a.Nbody.AppendNodes(&c.body)
out.Append(a)
},
)
} }
type constIntNodesByVal []*Node // binarySearch constructs a binary search tree for handling n cases,
// and appends it to out. It's used for efficiently implementing
// switch statements.
//
// less(i) should return a boolean expression. If it evaluates true,
// then cases [0, i) will be tested; otherwise, cases [i, n).
//
// base(i, out) should append statements to out to test the i'th case.
func binarySearch(n int, out *Nodes, less func(i int) *Node, base func(i int, out *Nodes)) {
const binarySearchMin = 4 // minimum number of cases for binary search
var do func(lo, hi int, out *Nodes)
do = func(lo, hi int, out *Nodes) {
n := hi - lo
if n < binarySearchMin {
for i := lo; i < hi; i++ {
base(i, out)
}
return
}
half := lo + n/2
nif := nod(OIF, nil, nil)
nif.Left = less(half)
nif.Left = typecheck(nif.Left, ctxExpr)
nif.Left = defaultlit(nif.Left, nil)
do(lo, half, &nif.Nbody)
do(half, hi, &nif.Rlist)
out.Append(nif)
}
func (x constIntNodesByVal) Len() int { return len(x) } do(0, n, out)
func (x constIntNodesByVal) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
func (x constIntNodesByVal) Less(i, j int) bool {
return x[i].Val().U.(*Mpint).Cmp(x[j].Val().U.(*Mpint)) < 0
} }
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gc
import (
"testing"
)
func nodrune(r rune) *Node {
v := new(Mpint)
v.SetInt64(int64(r))
v.Rune = true
return nodlit(Val{v})
}
func nodflt(f float64) *Node {
v := newMpflt()
v.SetFloat64(f)
return nodlit(Val{v})
}
func TestCaseClauseByConstVal(t *testing.T) {
tests := []struct {
a, b *Node
}{
// CTFLT
{nodflt(0.1), nodflt(0.2)},
// CTINT
{nodintconst(0), nodintconst(1)},
// CTRUNE
{nodrune('a'), nodrune('b')},
// CTSTR
{nodlit(Val{"ab"}), nodlit(Val{"abc"})},
{nodlit(Val{"ab"}), nodlit(Val{"xyz"})},
{nodlit(Val{"abc"}), nodlit(Val{"xyz"})},
}
for i, test := range tests {
a := caseClause{node: nod(OXXX, test.a, nil)}
b := caseClause{node: nod(OXXX, test.b, nil)}
s := caseClauseByConstVal{a, b}
if less := s.Less(0, 1); !less {
t.Errorf("%d: caseClauseByConstVal(%v, %v) = false", i, test.a, test.b)
}
if less := s.Less(1, 0); less {
t.Errorf("%d: caseClauseByConstVal(%v, %v) = true", i, test.a, test.b)
}
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment