Commit 8f3d9855 authored by Matthew Dempsky's avatar Matthew Dempsky

cmd/compile: major refactoring of switch walking

There are a lot of complexities to handling switches efficiently:

1. Order matters for expression switches with non-constant cases and
for type expressions with interface types. We have to respect
side-effects, and we also can't allow later cases to accidentally take
precedence over earlier cases.

2. For runs of integers, floats, and string constants in expression
switches or runs of concrete types in type switches, we want to emit
efficient binary searches.

3. For runs of consecutive integers in expression switches, we want to
collapse them into range comparisons.

4. For binary searches of strings, we want to compare by length first,
because that's more efficient and we don't need to respect any
particular ordering.

5. For "switch true { ... }" and "switch false { ... }", we want to
optimize "case x:" as simply "if x" or "if !x", respectively, unless x
is interface-typed.

The current swt.go code reflects how these constraints have been
incrementally added over time, with each of them being handled ad
hocly in different parts of the code. Also, the existing code tries
very hard to reuse logic between expression and type switches, even
though the similarities are very superficial.

This CL rewrites switch handling to better abstract away the logic
involved in constructing the binary searches. In particular, it's
intended to make further optimizations to switch dispatch much easier.

It also eliminates the need for both OXCASE and OCASE ops, and a
subsequent CL can collapse the two.

Passes toolstash-check.

Change-Id: Ifcd1e56f81f858117a412971d82e98abe7c4481f
Reviewed-on: https://go-review.googlesource.com/c/go/+/194660
Run-TryBot: Matthew Dempsky <mdempsky@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarKeith Randall <khr@golang.org>
parent 115e4c9c
...@@ -10,50 +10,6 @@ import ( ...@@ -10,50 +10,6 @@ import (
"sort" "sort"
) )
const (
// expression switch
switchKindExpr = iota // switch a {...} or switch 5 {...}
switchKindTrue // switch true {...} or switch {...}
switchKindFalse // switch false {...}
)
const (
binarySearchMin = 4 // minimum number of cases for binary search
integerRangeMin = 2 // minimum size of integer ranges
)
// An exprSwitch walks an expression switch.
type exprSwitch struct {
exprname *Node // node for the expression being switched on
kind int // kind of switch statement (switchKind*)
}
// A typeSwitch walks a type switch.
type typeSwitch struct {
hashname *Node // node for the hash of the type of the variable being switched on
facename *Node // node for the concrete type of the variable being switched on
okname *Node // boolean node used for comma-ok type assertions
}
// A caseClause is a single case clause in a switch statement.
type caseClause struct {
node *Node // points at case statement
ordinal int // position in switch
hash uint32 // hash of a type switch
// isconst indicates whether this case clause is a constant,
// for the purposes of the switch code generation.
// For expression switches, that's generally literals (case 5:, not case x:).
// For type switches, that's concrete types (case time.Time:), not interfaces (case io.Reader:).
isconst bool
}
// caseClauses are all the case clauses in a switch statement.
type caseClauses struct {
list []caseClause // general cases
defjmp *Node // OGOTO for default case or OBREAK if no default case present
niljmp *Node // OGOTO for nil type case in a type switch
}
// typecheckswitch typechecks a switch statement. // typecheckswitch typechecks a switch statement.
func typecheckswitch(n *Node) { func typecheckswitch(n *Node) {
typecheckslice(n.Ninit.Slice(), ctxStmt) typecheckslice(n.Ninit.Slice(), ctxStmt)
...@@ -71,7 +27,6 @@ func typecheckTypeSwitch(n *Node) { ...@@ -71,7 +27,6 @@ func typecheckTypeSwitch(n *Node) {
yyerrorl(n.Pos, "cannot type switch on non-interface value %L", n.Left.Right) yyerrorl(n.Pos, "cannot type switch on non-interface value %L", n.Left.Right)
t = nil t = nil
} }
n.Type = t // TODO(mdempsky): Remove; statements aren't typed.
// We don't actually declare the type switch's guarded // We don't actually declare the type switch's guarded
// declaration itself. So if there are no cases, we won't // declaration itself. So if there are no cases, we won't
...@@ -212,7 +167,6 @@ func typecheckExprSwitch(n *Node) { ...@@ -212,7 +167,6 @@ func typecheckExprSwitch(n *Node) {
t = nil t = nil
} }
} }
n.Type = t // TODO(mdempsky): Remove; statements aren't typed.
var defCase *Node var defCase *Node
var cs constSet var cs constSet
...@@ -265,422 +219,267 @@ func typecheckExprSwitch(n *Node) { ...@@ -265,422 +219,267 @@ func typecheckExprSwitch(n *Node) {
// walkswitch walks a switch statement. // walkswitch walks a switch statement.
func walkswitch(sw *Node) { func walkswitch(sw *Node) {
// convert switch {...} to switch true {...} // Guard against double walk, see #25776.
if sw.Left == nil { if sw.List.Len() == 0 && sw.Nbody.Len() > 0 {
sw.Left = nodbool(true) return // Was fatal, but eliminating every possible source of double-walking is hard
sw.Left = typecheck(sw.Left, ctxExpr)
sw.Left = defaultlit(sw.Left, nil)
} }
if sw.Left.Op == OTYPESW { if sw.Left != nil && sw.Left.Op == OTYPESW {
var s typeSwitch walkTypeSwitch(sw)
s.walk(sw)
} else { } else {
var s exprSwitch walkExprSwitch(sw)
s.walk(sw)
} }
} }
// walk generates an AST implementing sw. // walkExprSwitch generates an AST implementing sw. sw is an
// sw is an expression switch. // expression switch.
// The AST is generally of the form of a linear func walkExprSwitch(sw *Node) {
// search using if..goto, although binary search lno := setlineno(sw)
// is used with long runs of constants.
func (s *exprSwitch) walk(sw *Node) {
// Guard against double walk, see #25776.
if sw.List.Len() == 0 && sw.Nbody.Len() > 0 {
return // Was fatal, but eliminating every possible source of double-walking is hard
}
casebody(sw, nil)
cond := sw.Left cond := sw.Left
sw.Left = nil sw.Left = nil
s.kind = switchKindExpr // convert switch {...} to switch true {...}
if Isconst(cond, CTBOOL) { if cond == nil {
s.kind = switchKindTrue cond = nodbool(true)
if !cond.Val().U.(bool) { cond = typecheck(cond, ctxExpr)
s.kind = switchKindFalse cond = defaultlit(cond, nil)
}
} }
// Given "switch string(byteslice)", // Given "switch string(byteslice)",
// with all cases being constants (or the default case), // with all cases being side-effect free,
// use a zero-cost alias of the byte slice. // use a zero-cost alias of the byte slice.
// In theory, we could be more aggressive,
// allowing any side-effect-free expressions in cases,
// but it's a bit tricky because some of that information
// is unavailable due to the introduction of temporaries during order.
// Restricting to constants is simple and probably powerful enough.
// Do this before calling walkexpr on cond, // Do this before calling walkexpr on cond,
// because walkexpr will lower the string // because walkexpr will lower the string
// conversion into a runtime call. // conversion into a runtime call.
// See issue 24937 for more discussion. // See issue 24937 for more discussion.
if cond.Op == OBYTES2STR { if cond.Op == OBYTES2STR && allCaseExprsAreSideEffectFree(sw) {
ok := true
for _, cas := range sw.List.Slice() {
if cas.Op != OCASE {
Fatalf("switch string(byteslice) bad op: %v", cas.Op)
}
if cas.Left != nil && !Isconst(cas.Left, CTSTR) {
ok = false
break
}
}
if ok {
cond.Op = OBYTES2STRTMP cond.Op = OBYTES2STRTMP
} }
}
cond = walkexpr(cond, &sw.Ninit) cond = walkexpr(cond, &sw.Ninit)
t := sw.Type if cond.Op != OLITERAL {
if t == nil { cond = copyexpr(cond, cond.Type, &sw.Nbody)
return
} }
// convert the switch into OIF statements lineno = lno
var cas []*Node
if s.kind == switchKindTrue || s.kind == switchKindFalse { s := exprSwitch{
s.exprname = nodbool(s.kind == switchKindTrue) exprname: cond,
} else if consttype(cond) > 0 {
// leave constants to enable dead code elimination (issue 9608)
s.exprname = cond
} else {
s.exprname = temp(cond.Type)
cas = []*Node{nod(OAS, s.exprname, cond)} // This gets walk()ed again in walkstmtlist just before end of this function. See #29562.
typecheckslice(cas, ctxStmt)
} }
// Enumerate the cases and prepare the default case. br := nod(OBREAK, nil, nil)
clauses := s.genCaseClauses(sw.List.Slice()) var defaultGoto *Node
sw.List.Set(nil) var body Nodes
cc := clauses.list for _, ncase := range sw.List.Slice() {
label := autolabel(".s")
jmp := npos(ncase.Pos, nodSym(OGOTO, nil, label))
// Process case dispatch.
if ncase.List.Len() == 0 {
if defaultGoto != nil {
Fatalf("duplicate default case not detected during typechecking")
}
defaultGoto = jmp
}
for _, n1 := range ncase.List.Slice() {
s.Add(ncase.Pos, n1, jmp)
}
// handle the cases in order // Process body.
for len(cc) > 0 { body.Append(npos(ncase.Pos, nodSym(OLABEL, nil, label)))
run := 1 body.Append(ncase.Nbody.Slice()...)
if okforcmp[t.Etype] && cc[0].isconst { if !hasFall(ncase.Nbody.Slice()) {
// do binary search on runs of constants body.Append(br)
for ; run < len(cc) && cc[run].isconst; run++ {
} }
// sort and compile constants
sort.Sort(caseClauseByConstVal(cc[:run]))
} }
sw.List.Set(nil)
a := s.walkCases(cc[:run]) if defaultGoto == nil {
cas = append(cas, a) defaultGoto = br
cc = cc[run:]
} }
// handle default case s.Emit(&sw.Nbody)
if nerrors == 0 { sw.Nbody.Append(defaultGoto)
cas = append(cas, clauses.defjmp) sw.Nbody.AppendNodes(&body)
sw.Nbody.Prepend(cas...)
walkstmtlist(sw.Nbody.Slice()) walkstmtlist(sw.Nbody.Slice())
}
} }
// walkCases generates an AST implementing the cases in cc. // An exprSwitch walks an expression switch.
func (s *exprSwitch) walkCases(cc []caseClause) *Node { type exprSwitch struct {
if len(cc) < binarySearchMin { exprname *Node // value being switched on
// linear search
var cas []*Node
for _, c := range cc {
n := c.node
lno := setlineno(n)
a := nod(OIF, nil, nil)
if rng := n.List.Slice(); rng != nil {
// Integer range.
// exprname is a temp or a constant,
// so it is safe to evaluate twice.
// In most cases, this conjunction will be
// rewritten by walkinrange into a single comparison.
low := nod(OGE, s.exprname, rng[0])
high := nod(OLE, s.exprname, rng[1])
a.Left = nod(OANDAND, low, high)
} else if (s.kind != switchKindTrue && s.kind != switchKindFalse) || assignop(n.Left.Type, s.exprname.Type, nil) == OCONVIFACE || assignop(s.exprname.Type, n.Left.Type, nil) == OCONVIFACE {
a.Left = nod(OEQ, s.exprname, n.Left) // if name == val
} else if s.kind == switchKindTrue {
a.Left = n.Left // if val
} else {
// s.kind == switchKindFalse
a.Left = nod(ONOT, n.Left, nil) // if !val
}
a.Left = typecheck(a.Left, ctxExpr)
a.Left = defaultlit(a.Left, nil)
a.Nbody.Set1(n.Right) // goto l
cas = append(cas, a) done Nodes
lineno = lno clauses []exprClause
} }
return liststmt(cas)
}
// find the middle and recur type exprClause struct {
half := len(cc) / 2 pos src.XPos
a := nod(OIF, nil, nil) lo, hi *Node
n := cc[half-1].node jmp *Node
var mid *Node
if rng := n.List.Slice(); rng != nil {
mid = rng[1] // high end of range
} else {
mid = n.Left
}
le := nod(OLE, s.exprname, mid)
if Isconst(mid, CTSTR) {
// Search by length and then by value; see caseClauseByConstVal.
lenlt := nod(OLT, nod(OLEN, s.exprname, nil), nod(OLEN, mid, nil))
leneq := nod(OEQ, nod(OLEN, s.exprname, nil), nod(OLEN, mid, nil))
a.Left = nod(OOROR, lenlt, nod(OANDAND, leneq, le))
} else {
a.Left = le
}
a.Left = typecheck(a.Left, ctxExpr)
a.Left = defaultlit(a.Left, nil)
a.Nbody.Set1(s.walkCases(cc[:half]))
a.Rlist.Set1(s.walkCases(cc[half:]))
return a
} }
// casebody builds separate lists of statements and cases. func (s *exprSwitch) Add(pos src.XPos, expr, jmp *Node) {
// It makes labels between cases and statements c := exprClause{pos: pos, lo: expr, hi: expr, jmp: jmp}
// and deals with fallthrough, break, and unreachable statements. if okforcmp[s.exprname.Type.Etype] && expr.Op == OLITERAL {
func casebody(sw *Node, typeswvar *Node) { s.clauses = append(s.clauses, c)
if sw.List.Len() == 0 {
return return
} }
lno := setlineno(sw) s.flush()
s.clauses = append(s.clauses, c)
s.flush()
}
var cas []*Node // cases func (s *exprSwitch) Emit(out *Nodes) {
var stat []*Node // statements s.flush()
var def *Node // defaults out.AppendNodes(&s.done)
br := nod(OBREAK, nil, nil) }
for _, n := range sw.List.Slice() { func (s *exprSwitch) flush() {
setlineno(n) cc := s.clauses
if n.Op != OXCASE { s.clauses = nil
Fatalf("casebody %v", n.Op) if len(cc) == 0 {
} return
n.Op = OCASE
needvar := n.List.Len() != 1 || n.List.First().Op == OLITERAL
lbl := autolabel(".s")
jmp := nodSym(OGOTO, nil, lbl)
switch n.List.Len() {
case 0:
// default
if def != nil {
yyerrorl(n.Pos, "more than one default case")
}
// reuse original default case
n.Right = jmp
def = n
case 1:
// one case -- reuse OCASE node
n.Left = n.List.First()
n.Right = jmp
n.List.Set(nil)
cas = append(cas, n)
default:
// Expand multi-valued cases and detect ranges of integer cases.
if typeswvar != nil || sw.Left.Type.IsInterface() || !n.List.First().Type.IsInteger() || n.List.Len() < integerRangeMin {
// Can't use integer ranges. Expand each case into a separate node.
for _, n1 := range n.List.Slice() {
cas = append(cas, nod(OCASE, n1, jmp))
}
break
}
// Find integer ranges within runs of constants.
s := n.List.Slice()
j := 0
for j < len(s) {
// Find a run of constants.
var run int
for run = j; run < len(s) && Isconst(s[run], CTINT); run++ {
}
if run-j >= integerRangeMin {
// Search for integer ranges in s[j:run].
// Typechecking is done, so all values are already in an appropriate range.
search := s[j:run]
sort.Sort(constIntNodesByVal(search))
for beg, end := 0, 1; end <= len(search); end++ {
if end < len(search) && search[end].Int64() == search[end-1].Int64()+1 {
continue
}
if end-beg >= integerRangeMin {
// Record range in List.
c := nod(OCASE, nil, jmp)
c.List.Set2(search[beg], search[end-1])
cas = append(cas, c)
} else {
// Not large enough for range; record separately.
for _, n := range search[beg:end] {
cas = append(cas, nod(OCASE, n, jmp))
}
}
beg = end
}
j = run
}
// Advance to next constant, adding individual non-constant
// or as-yet-unhandled constant cases as we go.
for ; j < len(s) && (j < run || !Isconst(s[j], CTINT)); j++ {
cas = append(cas, nod(OCASE, s[j], jmp))
}
}
} }
stat = append(stat, nodSym(OLABEL, nil, lbl)) // Caution: If len(cc) == 1, then cc[0] might not an OLITERAL.
if typeswvar != nil && needvar && n.Rlist.Len() != 0 { // The code below is structured to implicitly handle this case
l := []*Node{ // (e.g., sort.Slice doesn't need to invoke the less function
nod(ODCL, n.Rlist.First(), nil), // when there's only a single slice element).
nod(OAS, n.Rlist.First(), typeswvar),
}
typecheckslice(l, ctxStmt)
stat = append(stat, l...)
}
stat = append(stat, n.Nbody.Slice()...)
// Search backwards for the index of the fallthrough // Sort strings by length and then by value.
// statement. Do not assume it'll be in the last // It is much cheaper to compare lengths than values,
// position, since in some cases (e.g. when the statement // and all we need here is consistency.
// list contains autotmp_ variables), one or more OVARKILL // We respect this sorting below.
// nodes will be at the end of the list. sort.Slice(cc, func(i, j int) bool {
fallIndex := len(stat) - 1 vi := cc[i].lo.Val()
for stat[fallIndex].Op == OVARKILL { vj := cc[j].lo.Val()
fallIndex--
if s.exprname.Type.IsString() {
si := vi.U.(string)
sj := vj.U.(string)
if len(si) != len(sj) {
return len(si) < len(sj)
}
return si < sj
}
return compareOp(vi, OLT, vj)
})
// Merge consecutive integer cases.
if s.exprname.Type.IsInteger() {
merged := cc[:1]
for _, c := range cc[1:] {
last := &merged[len(merged)-1]
if last.jmp == c.jmp && last.hi.Int64()+1 == c.lo.Int64() {
last.hi = c.lo
} else {
merged = append(merged, c)
} }
last := stat[fallIndex]
if last.Op != OFALL {
stat = append(stat, br)
} }
cc = merged
} }
stat = append(stat, br) binarySearch(len(cc), &s.done,
if def != nil { func(i int) *Node {
cas = append(cas, def) mid := cc[i-1].hi
}
sw.List.Set(cas) le := nod(OLE, s.exprname, mid)
sw.Nbody.Set(stat) if s.exprname.Type.IsString() {
lineno = lno // Compare strings by length and then
// by value; see sort.Slice above.
lenlt := nod(OLT, nod(OLEN, s.exprname, nil), nod(OLEN, mid, nil))
leneq := nod(OEQ, nod(OLEN, s.exprname, nil), nod(OLEN, mid, nil))
le = nod(OOROR, lenlt, nod(OANDAND, leneq, le))
}
return le
},
func(i int, out *Nodes) {
c := &cc[i]
nif := nodl(c.pos, OIF, c.test(s.exprname), nil)
nif.Left = typecheck(nif.Left, ctxExpr)
nif.Left = defaultlit(nif.Left, nil)
nif.Nbody.Set1(c.jmp)
out.Append(nif)
},
)
} }
// genCaseClauses generates the caseClauses value for clauses. func (c *exprClause) test(exprname *Node) *Node {
func (s *exprSwitch) genCaseClauses(clauses []*Node) caseClauses { // Integer range.
var cc caseClauses if c.hi != c.lo {
for _, n := range clauses { low := nodl(c.pos, OGE, exprname, c.lo)
if n.Left == nil && n.List.Len() == 0 { high := nodl(c.pos, OLE, exprname, c.hi)
// default case return nodl(c.pos, OANDAND, low, high)
if cc.defjmp != nil {
Fatalf("duplicate default case not detected during typechecking")
}
cc.defjmp = n.Right
continue
}
c := caseClause{node: n, ordinal: len(cc.list)}
if n.List.Len() > 0 {
c.isconst = true
} }
switch consttype(n.Left) {
case CTFLT, CTINT, CTRUNE, CTSTR: // Optimize "switch true { ...}" and "switch false { ... }".
c.isconst = true if Isconst(exprname, CTBOOL) && !c.lo.Type.IsInterface() {
if exprname.Val().U.(bool) {
return c.lo
} else {
return nodl(c.pos, ONOT, c.lo, nil)
} }
cc.list = append(cc.list, c)
} }
if cc.defjmp == nil { return nodl(c.pos, OEQ, exprname, c.lo)
cc.defjmp = nod(OBREAK, nil, nil)
}
return cc
} }
// genCaseClauses generates the caseClauses value for clauses. func allCaseExprsAreSideEffectFree(sw *Node) bool {
func (s *typeSwitch) genCaseClauses(clauses []*Node) caseClauses { // In theory, we could be more aggressive, allowing any
var cc caseClauses // side-effect-free expressions in cases, but it's a bit
for _, n := range clauses { // tricky because some of that information is unavailable due
switch { // to the introduction of temporaries during order.
case n.Left == nil: // Restricting to constants is simple and probably powerful
// default case // enough.
if cc.defjmp != nil {
Fatalf("duplicate default case not detected during typechecking")
}
cc.defjmp = n.Right
continue
case n.Left.Op == OLITERAL:
// nil case in type switch
if cc.niljmp != nil {
Fatalf("duplicate nil case not detected during typechecking")
}
cc.niljmp = n.Right
continue
}
// general case for _, ncase := range sw.List.Slice() {
c := caseClause{ if ncase.Op != OXCASE {
node: n, Fatalf("switch string(byteslice) bad op: %v", ncase.Op)
ordinal: len(cc.list),
isconst: !n.Left.Type.IsInterface(),
hash: typehash(n.Left.Type),
} }
cc.list = append(cc.list, c) for _, v := range ncase.List.Slice() {
if v.Op != OLITERAL {
return false
} }
if cc.defjmp == nil {
cc.defjmp = nod(OBREAK, nil, nil)
} }
}
return cc return true
} }
// walk generates an AST that implements sw, // hasFall reports whether stmts ends with a "fallthrough" statement.
// where sw is a type switch. func hasFall(stmts []*Node) bool {
// The AST is generally of the form of a linear // Search backwards for the index of the fallthrough
// search using if..goto, although binary search // statement. Do not assume it'll be in the last
// is used with long runs of concrete types. // position, since in some cases (e.g. when the statement
func (s *typeSwitch) walk(sw *Node) { // list contains autotmp_ variables), one or more OVARKILL
cond := sw.Left // nodes will be at the end of the list.
sw.Left = nil
if cond == nil {
sw.List.Set(nil)
return
}
if cond.Right == nil {
yyerrorl(sw.Pos, "type switch must have an assignment")
return
}
cond.Right = walkexpr(cond.Right, &sw.Ninit) i := len(stmts) - 1
if !cond.Right.Type.IsInterface() { for i >= 0 && stmts[i].Op == OVARKILL {
yyerrorl(sw.Pos, "type switch must be on an interface") i--
return
} }
return i >= 0 && stmts[i].Op == OFALL
}
var cas []*Node // walkTypeSwitch generates an AST that implements sw, where sw is a
// type switch.
// predeclare temporary variables and the boolean var func walkTypeSwitch(sw *Node) {
s.facename = temp(cond.Right.Type) var s typeSwitch
s.facename = sw.Left.Right
a := nod(OAS, s.facename, cond.Right) sw.Left = nil
a = typecheck(a, ctxStmt)
cas = append(cas, a)
s.facename = walkexpr(s.facename, &sw.Ninit)
s.facename = copyexpr(s.facename, s.facename.Type, &sw.Nbody)
s.okname = temp(types.Types[TBOOL]) s.okname = temp(types.Types[TBOOL])
s.okname = typecheck(s.okname, ctxExpr)
s.hashname = temp(types.Types[TUINT32])
s.hashname = typecheck(s.hashname, ctxExpr)
// set up labels and jumps // Get interface descriptor word.
casebody(sw, s.facename) // For empty interfaces this will be the type.
// For non-empty interfaces this will be the itab.
clauses := s.genCaseClauses(sw.List.Slice()) itab := nod(OITAB, s.facename, nil)
sw.List.Set(nil)
def := clauses.defjmp
// For empty interfaces, do: // For empty interfaces, do:
// if e._type == nil { // if e._type == nil {
...@@ -688,230 +487,235 @@ func (s *typeSwitch) walk(sw *Node) { ...@@ -688,230 +487,235 @@ func (s *typeSwitch) walk(sw *Node) {
// } // }
// h := e._type.hash // h := e._type.hash
// Use a similar strategy for non-empty interfaces. // Use a similar strategy for non-empty interfaces.
ifNil := nod(OIF, nil, nil)
ifNil.Left = nod(OEQ, itab, nodnil())
ifNil.Left = typecheck(ifNil.Left, ctxExpr)
ifNil.Left = defaultlit(ifNil.Left, nil)
// ifNil.Nbody assigned at end.
sw.Nbody.Append(ifNil)
// Get interface descriptor word. // Load hash from type or itab.
// For empty interfaces this will be the type. dotHash := nodSym(ODOTPTR, itab, nil)
// For non-empty interfaces this will be the itab. dotHash.Type = types.Types[TUINT32]
itab := nod(OITAB, s.facename, nil) dotHash.SetTypecheck(1)
if s.facename.Type.IsEmptyInterface() {
// Check for nil first. dotHash.Xoffset = int64(2 * Widthptr) // offset of hash in runtime._type
i := nod(OIF, nil, nil)
i.Left = nod(OEQ, itab, nodnil())
if clauses.niljmp != nil {
// Do explicit nil case right here.
i.Nbody.Set1(clauses.niljmp)
} else { } else {
// Jump to default case. dotHash.Xoffset = int64(2 * Widthptr) // offset of hash in runtime.itab
lbl := autolabel(".s")
i.Nbody.Set1(nodSym(OGOTO, nil, lbl))
// Wrap default case with label.
blk := nod(OBLOCK, nil, nil)
blk.List.Set2(nodSym(OLABEL, nil, lbl), def)
def = blk
} }
i.Left = typecheck(i.Left, ctxExpr) dotHash.SetBounded(true) // guaranteed not to fault
i.Left = defaultlit(i.Left, nil) s.hashname = copyexpr(dotHash, dotHash.Type, &sw.Nbody)
cas = append(cas, i)
// Load hash from type or itab. br := nod(OBREAK, nil, nil)
h := nodSym(ODOTPTR, itab, nil) var defaultGoto, nilGoto *Node
h.Type = types.Types[TUINT32] var body Nodes
h.SetTypecheck(1) for _, ncase := range sw.List.Slice() {
if cond.Right.Type.IsEmptyInterface() { var caseVar *Node
h.Xoffset = int64(2 * Widthptr) // offset of hash in runtime._type if ncase.Rlist.Len() != 0 {
} else { caseVar = ncase.Rlist.First()
h.Xoffset = int64(2 * Widthptr) // offset of hash in runtime.itab
} }
h.SetBounded(true) // guaranteed not to fault
a = nod(OAS, s.hashname, h)
a = typecheck(a, ctxStmt)
cas = append(cas, a)
cc := clauses.list // For single-type cases, we initialize the case
// variable as part of the type assertion; but in
// other cases, we initialize it in the body.
singleType := ncase.List.Len() == 1 && ncase.List.First().Op == OTYPE
// insert type equality check into each case block label := autolabel(".s")
for _, c := range cc {
c.node.Right = s.typeone(c.node) jmp := npos(ncase.Pos, nodSym(OGOTO, nil, label))
if ncase.List.Len() == 0 { // default:
if defaultGoto != nil {
Fatalf("duplicate default case not detected during typechecking")
}
defaultGoto = jmp
} }
// generate list of if statements, binary search for constant sequences for _, n1 := range ncase.List.Slice() {
for len(cc) > 0 { if n1.isNil() { // case nil:
if !cc[0].isconst { if nilGoto != nil {
n := cc[0].node Fatalf("duplicate nil case not detected during typechecking")
cas = append(cas, n.Right) }
cc = cc[1:] nilGoto = jmp
continue continue
} }
// identify run of constants if singleType {
var run int s.Add(n1.Type, caseVar, jmp)
for run = 1; run < len(cc) && cc[run].isconst; run++ { } else {
s.Add(n1.Type, nil, jmp)
}
} }
// sort by hash body.Append(npos(ncase.Pos, nodSym(OLABEL, nil, label)))
sort.Sort(caseClauseByType(cc[:run])) if caseVar != nil && !singleType {
l := []*Node{
// for debugging: linear search nodl(ncase.Pos, ODCL, caseVar, nil),
if false { nodl(ncase.Pos, OAS, caseVar, s.facename),
for i := 0; i < run; i++ {
n := cc[i].node
cas = append(cas, n.Right)
} }
continue typecheckslice(l, ctxStmt)
body.Append(l...)
} }
body.Append(ncase.Nbody.Slice()...)
// combine adjacent cases with the same hash body.Append(br)
var batch []caseClause
for i, j := 0, 0; i < run; i = j {
hash := []*Node{cc[i].node.Right}
for j = i + 1; j < run && cc[i].hash == cc[j].hash; j++ {
hash = append(hash, cc[j].node.Right)
} }
cc[i].node.Right = liststmt(hash) sw.List.Set(nil)
batch = append(batch, cc[i])
if defaultGoto == nil {
defaultGoto = br
} }
// binary search among cases to narrow by hash if nilGoto != nil {
cas = append(cas, s.walkCases(batch)) ifNil.Nbody.Set1(nilGoto)
cc = cc[run:] } else {
// TODO(mdempsky): Just use defaultGoto directly.
// Jump to default case.
label := autolabel(".s")
ifNil.Nbody.Set1(nodSym(OGOTO, nil, label))
// Wrap default case with label.
blk := nod(OBLOCK, nil, nil)
blk.List.Set2(nodSym(OLABEL, nil, label), defaultGoto)
defaultGoto = blk
} }
// handle default case s.Emit(&sw.Nbody)
if nerrors == 0 { sw.Nbody.Append(defaultGoto)
cas = append(cas, def) sw.Nbody.AppendNodes(&body)
sw.Nbody.Prepend(cas...)
sw.List.Set(nil)
walkstmtlist(sw.Nbody.Slice()) walkstmtlist(sw.Nbody.Slice())
}
} }
// typeone generates an AST that jumps to the // A typeSwitch walks a type switch.
// case body if the variable is of type t. type typeSwitch struct {
func (s *typeSwitch) typeone(t *Node) *Node { // Temporary variables (i.e., ONAMEs) used by type switch dispatch logic:
var name *Node facename *Node // value being type-switched on
var init Nodes hashname *Node // type hash of the value being type-switched on
if t.Rlist.Len() == 0 { okname *Node // boolean used for comma-ok type assertions
name = nblank
nblank = typecheck(nblank, ctxExpr|ctxAssign) done Nodes
} else { clauses []typeClause
name = t.Rlist.First()
init.Append(nod(ODCL, name, nil))
a := nod(OAS, name, nil)
a = typecheck(a, ctxStmt)
init.Append(a)
}
a := nod(OAS2, nil, nil)
a.List.Set2(name, s.okname) // name, ok =
b := nod(ODOTTYPE, s.facename, nil)
b.Type = t.Left.Type // interface.(type)
a.Rlist.Set1(b)
a = typecheck(a, ctxStmt)
a = walkexpr(a, &init)
init.Append(a)
c := nod(OIF, nil, nil)
c.Left = s.okname
c.Nbody.Set1(t.Right) // if ok { goto l }
init.Append(c)
return init.asblock()
} }
// walkCases generates an AST implementing the cases in cc. type typeClause struct {
func (s *typeSwitch) walkCases(cc []caseClause) *Node { hash uint32
if len(cc) < binarySearchMin { body Nodes
var cas []*Node }
for _, c := range cc {
n := c.node func (s *typeSwitch) Add(typ *types.Type, caseVar *Node, jmp *Node) {
if !c.isconst { var body Nodes
Fatalf("typeSwitch walkCases") if caseVar != nil {
} l := []*Node{
a := nod(OIF, nil, nil) nod(ODCL, caseVar, nil),
a.Left = nod(OEQ, s.hashname, nodintconst(int64(c.hash))) nod(OAS, caseVar, nil),
a.Left = typecheck(a.Left, ctxExpr)
a.Left = defaultlit(a.Left, nil)
a.Nbody.Set1(n.Right)
cas = append(cas, a)
} }
return liststmt(cas) typecheckslice(l, ctxStmt)
body.Append(l...)
} else {
caseVar = nblank
}
// cv, ok = iface.(type)
as := nod(OAS2, nil, nil)
as.List.Set2(caseVar, s.okname) // cv, ok =
dot := nod(ODOTTYPE, s.facename, nil)
dot.Type = typ // iface.(type)
as.Rlist.Set1(dot)
as = typecheck(as, ctxStmt)
as = walkexpr(as, &body)
body.Append(as)
// if ok { goto label }
nif := nod(OIF, nil, nil)
nif.Left = s.okname
nif.Nbody.Set1(jmp)
body.Append(nif)
if !typ.IsInterface() {
s.clauses = append(s.clauses, typeClause{
hash: typehash(typ),
body: body,
})
return
} }
// find the middle and recur s.flush()
half := len(cc) / 2 s.done.AppendNodes(&body)
a := nod(OIF, nil, nil)
a.Left = nod(OLE, s.hashname, nodintconst(int64(cc[half-1].hash)))
a.Left = typecheck(a.Left, ctxExpr)
a.Left = defaultlit(a.Left, nil)
a.Nbody.Set1(s.walkCases(cc[:half]))
a.Rlist.Set1(s.walkCases(cc[half:]))
return a
} }
// caseClauseByConstVal sorts clauses by constant value to enable binary search. func (s *typeSwitch) Emit(out *Nodes) {
type caseClauseByConstVal []caseClause s.flush()
out.AppendNodes(&s.done)
func (x caseClauseByConstVal) Len() int { return len(x) } }
func (x caseClauseByConstVal) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
func (x caseClauseByConstVal) Less(i, j int) bool {
// n1 and n2 might be individual constants or integer ranges.
// We have checked for duplicates already,
// so ranges can be safely represented by any value in the range.
n1 := x[i].node
var v1 interface{}
if s := n1.List.Slice(); s != nil {
v1 = s[0].Val().U
} else {
v1 = n1.Left.Val().U
}
n2 := x[j].node func (s *typeSwitch) flush() {
var v2 interface{} cc := s.clauses
if s := n2.List.Slice(); s != nil { s.clauses = nil
v2 = s[0].Val().U if len(cc) == 0 {
} else { return
v2 = n2.Left.Val().U
} }
switch v1 := v1.(type) { sort.Slice(cc, func(i, j int) bool { return cc[i].hash < cc[j].hash })
case *Mpflt:
return v1.Cmp(v2.(*Mpflt)) < 0 // Combine adjacent cases with the same hash.
case *Mpint: merged := cc[:1]
return v1.Cmp(v2.(*Mpint)) < 0 for _, c := range cc[1:] {
case string: last := &merged[len(merged)-1]
// Sort strings by length and then by value. if last.hash == c.hash {
// It is much cheaper to compare lengths than values, last.body.AppendNodes(&c.body)
// and all we need here is consistency. } else {
// We respect this sorting in exprSwitch.walkCases. merged = append(merged, c)
a := v1
b := v2.(string)
if len(a) != len(b) {
return len(a) < len(b)
} }
return a < b
} }
cc = merged
Fatalf("caseClauseByConstVal passed bad clauses %v < %v", x[i].node.Left, x[j].node.Left) binarySearch(len(cc), &s.done,
return false func(i int) *Node {
return nod(OLE, s.hashname, nodintconst(int64(cc[i-1].hash)))
},
func(i int, out *Nodes) {
// TODO(mdempsky): Omit hash equality check if
// there's only one type.
c := cc[i]
a := nod(OIF, nil, nil)
a.Left = nod(OEQ, s.hashname, nodintconst(int64(c.hash)))
a.Left = typecheck(a.Left, ctxExpr)
a.Left = defaultlit(a.Left, nil)
a.Nbody.AppendNodes(&c.body)
out.Append(a)
},
)
} }
type caseClauseByType []caseClause // binarySearch constructs a binary search tree for handling n cases,
// and appends it to out. It's used for efficiently implementing
func (x caseClauseByType) Len() int { return len(x) } // switch statements.
func (x caseClauseByType) Swap(i, j int) { x[i], x[j] = x[j], x[i] } //
func (x caseClauseByType) Less(i, j int) bool { // less(i) should return a boolean expression. If it evaluates true,
c1, c2 := x[i], x[j] // then cases [0, i) will be tested; otherwise, cases [i, n).
// sort by hash code, then ordinal (for the rare case of hash collisions) //
if c1.hash != c2.hash { // base(i, out) should append statements to out to test the i'th case.
return c1.hash < c2.hash func binarySearch(n int, out *Nodes, less func(i int) *Node, base func(i int, out *Nodes)) {
const binarySearchMin = 4 // minimum number of cases for binary search
var do func(lo, hi int, out *Nodes)
do = func(lo, hi int, out *Nodes) {
n := hi - lo
if n < binarySearchMin {
for i := lo; i < hi; i++ {
base(i, out)
}
return
} }
return c1.ordinal < c2.ordinal
}
type constIntNodesByVal []*Node half := lo + n/2
nif := nod(OIF, nil, nil)
nif.Left = less(half)
nif.Left = typecheck(nif.Left, ctxExpr)
nif.Left = defaultlit(nif.Left, nil)
do(lo, half, &nif.Nbody)
do(half, hi, &nif.Rlist)
out.Append(nif)
}
func (x constIntNodesByVal) Len() int { return len(x) } do(0, n, out)
func (x constIntNodesByVal) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
func (x constIntNodesByVal) Less(i, j int) bool {
return x[i].Val().U.(*Mpint).Cmp(x[j].Val().U.(*Mpint)) < 0
} }
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gc
import (
"testing"
)
func nodrune(r rune) *Node {
v := new(Mpint)
v.SetInt64(int64(r))
v.Rune = true
return nodlit(Val{v})
}
func nodflt(f float64) *Node {
v := newMpflt()
v.SetFloat64(f)
return nodlit(Val{v})
}
func TestCaseClauseByConstVal(t *testing.T) {
tests := []struct {
a, b *Node
}{
// CTFLT
{nodflt(0.1), nodflt(0.2)},
// CTINT
{nodintconst(0), nodintconst(1)},
// CTRUNE
{nodrune('a'), nodrune('b')},
// CTSTR
{nodlit(Val{"ab"}), nodlit(Val{"abc"})},
{nodlit(Val{"ab"}), nodlit(Val{"xyz"})},
{nodlit(Val{"abc"}), nodlit(Val{"xyz"})},
}
for i, test := range tests {
a := caseClause{node: nod(OXXX, test.a, nil)}
b := caseClause{node: nod(OXXX, test.b, nil)}
s := caseClauseByConstVal{a, b}
if less := s.Less(0, 1); !less {
t.Errorf("%d: caseClauseByConstVal(%v, %v) = false", i, test.a, test.b)
}
if less := s.Less(1, 0); less {
t.Errorf("%d: caseClauseByConstVal(%v, %v) = true", i, test.a, test.b)
}
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment