Commit 8f3d9855 authored by Matthew Dempsky's avatar Matthew Dempsky

cmd/compile: major refactoring of switch walking

There are a lot of complexities to handling switches efficiently:

1. Order matters for expression switches with non-constant cases and
for type expressions with interface types. We have to respect
side-effects, and we also can't allow later cases to accidentally take
precedence over earlier cases.

2. For runs of integers, floats, and string constants in expression
switches or runs of concrete types in type switches, we want to emit
efficient binary searches.

3. For runs of consecutive integers in expression switches, we want to
collapse them into range comparisons.

4. For binary searches of strings, we want to compare by length first,
because that's more efficient and we don't need to respect any
particular ordering.

5. For "switch true { ... }" and "switch false { ... }", we want to
optimize "case x:" as simply "if x" or "if !x", respectively, unless x
is interface-typed.

The current swt.go code reflects how these constraints have been
incrementally added over time, with each of them being handled ad
hocly in different parts of the code. Also, the existing code tries
very hard to reuse logic between expression and type switches, even
though the similarities are very superficial.

This CL rewrites switch handling to better abstract away the logic
involved in constructing the binary searches. In particular, it's
intended to make further optimizations to switch dispatch much easier.

It also eliminates the need for both OXCASE and OCASE ops, and a
subsequent CL can collapse the two.

Passes toolstash-check.

Change-Id: Ifcd1e56f81f858117a412971d82e98abe7c4481f
Reviewed-on: https://go-review.googlesource.com/c/go/+/194660
Run-TryBot: Matthew Dempsky <mdempsky@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarKeith Randall <khr@golang.org>
parent 115e4c9c
......@@ -10,50 +10,6 @@ import (
"sort"
)
const (
// expression switch
switchKindExpr = iota // switch a {...} or switch 5 {...}
switchKindTrue // switch true {...} or switch {...}
switchKindFalse // switch false {...}
)
const (
binarySearchMin = 4 // minimum number of cases for binary search
integerRangeMin = 2 // minimum size of integer ranges
)
// An exprSwitch walks an expression switch.
type exprSwitch struct {
exprname *Node // node for the expression being switched on
kind int // kind of switch statement (switchKind*)
}
// A typeSwitch walks a type switch.
type typeSwitch struct {
hashname *Node // node for the hash of the type of the variable being switched on
facename *Node // node for the concrete type of the variable being switched on
okname *Node // boolean node used for comma-ok type assertions
}
// A caseClause is a single case clause in a switch statement.
type caseClause struct {
node *Node // points at case statement
ordinal int // position in switch
hash uint32 // hash of a type switch
// isconst indicates whether this case clause is a constant,
// for the purposes of the switch code generation.
// For expression switches, that's generally literals (case 5:, not case x:).
// For type switches, that's concrete types (case time.Time:), not interfaces (case io.Reader:).
isconst bool
}
// caseClauses are all the case clauses in a switch statement.
type caseClauses struct {
list []caseClause // general cases
defjmp *Node // OGOTO for default case or OBREAK if no default case present
niljmp *Node // OGOTO for nil type case in a type switch
}
// typecheckswitch typechecks a switch statement.
func typecheckswitch(n *Node) {
typecheckslice(n.Ninit.Slice(), ctxStmt)
......@@ -71,7 +27,6 @@ func typecheckTypeSwitch(n *Node) {
yyerrorl(n.Pos, "cannot type switch on non-interface value %L", n.Left.Right)
t = nil
}
n.Type = t // TODO(mdempsky): Remove; statements aren't typed.
// We don't actually declare the type switch's guarded
// declaration itself. So if there are no cases, we won't
......@@ -212,7 +167,6 @@ func typecheckExprSwitch(n *Node) {
t = nil
}
}
n.Type = t // TODO(mdempsky): Remove; statements aren't typed.
var defCase *Node
var cs constSet
......@@ -265,422 +219,267 @@ func typecheckExprSwitch(n *Node) {
// walkswitch walks a switch statement.
func walkswitch(sw *Node) {
// convert switch {...} to switch true {...}
if sw.Left == nil {
sw.Left = nodbool(true)
sw.Left = typecheck(sw.Left, ctxExpr)
sw.Left = defaultlit(sw.Left, nil)
// Guard against double walk, see #25776.
if sw.List.Len() == 0 && sw.Nbody.Len() > 0 {
return // Was fatal, but eliminating every possible source of double-walking is hard
}
if sw.Left.Op == OTYPESW {
var s typeSwitch
s.walk(sw)
if sw.Left != nil && sw.Left.Op == OTYPESW {
walkTypeSwitch(sw)
} else {
var s exprSwitch
s.walk(sw)
walkExprSwitch(sw)
}
}
// walk generates an AST implementing sw.
// sw is an expression switch.
// The AST is generally of the form of a linear
// search using if..goto, although binary search
// is used with long runs of constants.
func (s *exprSwitch) walk(sw *Node) {
// Guard against double walk, see #25776.
if sw.List.Len() == 0 && sw.Nbody.Len() > 0 {
return // Was fatal, but eliminating every possible source of double-walking is hard
}
casebody(sw, nil)
// walkExprSwitch generates an AST implementing sw. sw is an
// expression switch.
func walkExprSwitch(sw *Node) {
lno := setlineno(sw)
cond := sw.Left
sw.Left = nil
s.kind = switchKindExpr
if Isconst(cond, CTBOOL) {
s.kind = switchKindTrue
if !cond.Val().U.(bool) {
s.kind = switchKindFalse
}
// convert switch {...} to switch true {...}
if cond == nil {
cond = nodbool(true)
cond = typecheck(cond, ctxExpr)
cond = defaultlit(cond, nil)
}
// Given "switch string(byteslice)",
// with all cases being constants (or the default case),
// with all cases being side-effect free,
// use a zero-cost alias of the byte slice.
// In theory, we could be more aggressive,
// allowing any side-effect-free expressions in cases,
// but it's a bit tricky because some of that information
// is unavailable due to the introduction of temporaries during order.
// Restricting to constants is simple and probably powerful enough.
// Do this before calling walkexpr on cond,
// because walkexpr will lower the string
// conversion into a runtime call.
// See issue 24937 for more discussion.
if cond.Op == OBYTES2STR {
ok := true
for _, cas := range sw.List.Slice() {
if cas.Op != OCASE {
Fatalf("switch string(byteslice) bad op: %v", cas.Op)
}
if cas.Left != nil && !Isconst(cas.Left, CTSTR) {
ok = false
break
}
}
if ok {
if cond.Op == OBYTES2STR && allCaseExprsAreSideEffectFree(sw) {
cond.Op = OBYTES2STRTMP
}
}
cond = walkexpr(cond, &sw.Ninit)
t := sw.Type
if t == nil {
return
if cond.Op != OLITERAL {
cond = copyexpr(cond, cond.Type, &sw.Nbody)
}
// convert the switch into OIF statements
var cas []*Node
if s.kind == switchKindTrue || s.kind == switchKindFalse {
s.exprname = nodbool(s.kind == switchKindTrue)
} else if consttype(cond) > 0 {
// leave constants to enable dead code elimination (issue 9608)
s.exprname = cond
} else {
s.exprname = temp(cond.Type)
cas = []*Node{nod(OAS, s.exprname, cond)} // This gets walk()ed again in walkstmtlist just before end of this function. See #29562.
typecheckslice(cas, ctxStmt)
lineno = lno
s := exprSwitch{
exprname: cond,
}
// Enumerate the cases and prepare the default case.
clauses := s.genCaseClauses(sw.List.Slice())
sw.List.Set(nil)
cc := clauses.list
br := nod(OBREAK, nil, nil)
var defaultGoto *Node
var body Nodes
for _, ncase := range sw.List.Slice() {
label := autolabel(".s")
jmp := npos(ncase.Pos, nodSym(OGOTO, nil, label))
// Process case dispatch.
if ncase.List.Len() == 0 {
if defaultGoto != nil {
Fatalf("duplicate default case not detected during typechecking")
}
defaultGoto = jmp
}
for _, n1 := range ncase.List.Slice() {
s.Add(ncase.Pos, n1, jmp)
}
// handle the cases in order
for len(cc) > 0 {
run := 1
if okforcmp[t.Etype] && cc[0].isconst {
// do binary search on runs of constants
for ; run < len(cc) && cc[run].isconst; run++ {
// Process body.
body.Append(npos(ncase.Pos, nodSym(OLABEL, nil, label)))
body.Append(ncase.Nbody.Slice()...)
if !hasFall(ncase.Nbody.Slice()) {
body.Append(br)
}
// sort and compile constants
sort.Sort(caseClauseByConstVal(cc[:run]))
}
sw.List.Set(nil)
a := s.walkCases(cc[:run])
cas = append(cas, a)
cc = cc[run:]
if defaultGoto == nil {
defaultGoto = br
}
// handle default case
if nerrors == 0 {
cas = append(cas, clauses.defjmp)
sw.Nbody.Prepend(cas...)
s.Emit(&sw.Nbody)
sw.Nbody.Append(defaultGoto)
sw.Nbody.AppendNodes(&body)
walkstmtlist(sw.Nbody.Slice())
}
}
// walkCases generates an AST implementing the cases in cc.
func (s *exprSwitch) walkCases(cc []caseClause) *Node {
if len(cc) < binarySearchMin {
// linear search
var cas []*Node
for _, c := range cc {
n := c.node
lno := setlineno(n)
a := nod(OIF, nil, nil)
if rng := n.List.Slice(); rng != nil {
// Integer range.
// exprname is a temp or a constant,
// so it is safe to evaluate twice.
// In most cases, this conjunction will be
// rewritten by walkinrange into a single comparison.
low := nod(OGE, s.exprname, rng[0])
high := nod(OLE, s.exprname, rng[1])
a.Left = nod(OANDAND, low, high)
} else if (s.kind != switchKindTrue && s.kind != switchKindFalse) || assignop(n.Left.Type, s.exprname.Type, nil) == OCONVIFACE || assignop(s.exprname.Type, n.Left.Type, nil) == OCONVIFACE {
a.Left = nod(OEQ, s.exprname, n.Left) // if name == val
} else if s.kind == switchKindTrue {
a.Left = n.Left // if val
} else {
// s.kind == switchKindFalse
a.Left = nod(ONOT, n.Left, nil) // if !val
}
a.Left = typecheck(a.Left, ctxExpr)
a.Left = defaultlit(a.Left, nil)
a.Nbody.Set1(n.Right) // goto l
// An exprSwitch walks an expression switch.
type exprSwitch struct {
exprname *Node // value being switched on
cas = append(cas, a)
lineno = lno
}
return liststmt(cas)
}
done Nodes
clauses []exprClause
}
// find the middle and recur
half := len(cc) / 2
a := nod(OIF, nil, nil)
n := cc[half-1].node
var mid *Node
if rng := n.List.Slice(); rng != nil {
mid = rng[1] // high end of range
} else {
mid = n.Left
}
le := nod(OLE, s.exprname, mid)
if Isconst(mid, CTSTR) {
// Search by length and then by value; see caseClauseByConstVal.
lenlt := nod(OLT, nod(OLEN, s.exprname, nil), nod(OLEN, mid, nil))
leneq := nod(OEQ, nod(OLEN, s.exprname, nil), nod(OLEN, mid, nil))
a.Left = nod(OOROR, lenlt, nod(OANDAND, leneq, le))
} else {
a.Left = le
}
a.Left = typecheck(a.Left, ctxExpr)
a.Left = defaultlit(a.Left, nil)
a.Nbody.Set1(s.walkCases(cc[:half]))
a.Rlist.Set1(s.walkCases(cc[half:]))
return a
type exprClause struct {
pos src.XPos
lo, hi *Node
jmp *Node
}
// casebody builds separate lists of statements and cases.
// It makes labels between cases and statements
// and deals with fallthrough, break, and unreachable statements.
func casebody(sw *Node, typeswvar *Node) {
if sw.List.Len() == 0 {
func (s *exprSwitch) Add(pos src.XPos, expr, jmp *Node) {
c := exprClause{pos: pos, lo: expr, hi: expr, jmp: jmp}
if okforcmp[s.exprname.Type.Etype] && expr.Op == OLITERAL {
s.clauses = append(s.clauses, c)
return
}
lno := setlineno(sw)
s.flush()
s.clauses = append(s.clauses, c)
s.flush()
}
var cas []*Node // cases
var stat []*Node // statements
var def *Node // defaults
br := nod(OBREAK, nil, nil)
func (s *exprSwitch) Emit(out *Nodes) {
s.flush()
out.AppendNodes(&s.done)
}
for _, n := range sw.List.Slice() {
setlineno(n)
if n.Op != OXCASE {
Fatalf("casebody %v", n.Op)
}
n.Op = OCASE
needvar := n.List.Len() != 1 || n.List.First().Op == OLITERAL
lbl := autolabel(".s")
jmp := nodSym(OGOTO, nil, lbl)
switch n.List.Len() {
case 0:
// default
if def != nil {
yyerrorl(n.Pos, "more than one default case")
}
// reuse original default case
n.Right = jmp
def = n
case 1:
// one case -- reuse OCASE node
n.Left = n.List.First()
n.Right = jmp
n.List.Set(nil)
cas = append(cas, n)
default:
// Expand multi-valued cases and detect ranges of integer cases.
if typeswvar != nil || sw.Left.Type.IsInterface() || !n.List.First().Type.IsInteger() || n.List.Len() < integerRangeMin {
// Can't use integer ranges. Expand each case into a separate node.
for _, n1 := range n.List.Slice() {
cas = append(cas, nod(OCASE, n1, jmp))
}
break
}
// Find integer ranges within runs of constants.
s := n.List.Slice()
j := 0
for j < len(s) {
// Find a run of constants.
var run int
for run = j; run < len(s) && Isconst(s[run], CTINT); run++ {
}
if run-j >= integerRangeMin {
// Search for integer ranges in s[j:run].
// Typechecking is done, so all values are already in an appropriate range.
search := s[j:run]
sort.Sort(constIntNodesByVal(search))
for beg, end := 0, 1; end <= len(search); end++ {
if end < len(search) && search[end].Int64() == search[end-1].Int64()+1 {
continue
}
if end-beg >= integerRangeMin {
// Record range in List.
c := nod(OCASE, nil, jmp)
c.List.Set2(search[beg], search[end-1])
cas = append(cas, c)
} else {
// Not large enough for range; record separately.
for _, n := range search[beg:end] {
cas = append(cas, nod(OCASE, n, jmp))
}
}
beg = end
}
j = run
}
// Advance to next constant, adding individual non-constant
// or as-yet-unhandled constant cases as we go.
for ; j < len(s) && (j < run || !Isconst(s[j], CTINT)); j++ {
cas = append(cas, nod(OCASE, s[j], jmp))
}
}
func (s *exprSwitch) flush() {
cc := s.clauses
s.clauses = nil
if len(cc) == 0 {
return
}
stat = append(stat, nodSym(OLABEL, nil, lbl))
if typeswvar != nil && needvar && n.Rlist.Len() != 0 {
l := []*Node{
nod(ODCL, n.Rlist.First(), nil),
nod(OAS, n.Rlist.First(), typeswvar),
}
typecheckslice(l, ctxStmt)
stat = append(stat, l...)
}
stat = append(stat, n.Nbody.Slice()...)
// Caution: If len(cc) == 1, then cc[0] might not an OLITERAL.
// The code below is structured to implicitly handle this case
// (e.g., sort.Slice doesn't need to invoke the less function
// when there's only a single slice element).
// Search backwards for the index of the fallthrough
// statement. Do not assume it'll be in the last
// position, since in some cases (e.g. when the statement
// list contains autotmp_ variables), one or more OVARKILL
// nodes will be at the end of the list.
fallIndex := len(stat) - 1
for stat[fallIndex].Op == OVARKILL {
fallIndex--
// Sort strings by length and then by value.
// It is much cheaper to compare lengths than values,
// and all we need here is consistency.
// We respect this sorting below.
sort.Slice(cc, func(i, j int) bool {
vi := cc[i].lo.Val()
vj := cc[j].lo.Val()
if s.exprname.Type.IsString() {
si := vi.U.(string)
sj := vj.U.(string)
if len(si) != len(sj) {
return len(si) < len(sj)
}
return si < sj
}
return compareOp(vi, OLT, vj)
})
// Merge consecutive integer cases.
if s.exprname.Type.IsInteger() {
merged := cc[:1]
for _, c := range cc[1:] {
last := &merged[len(merged)-1]
if last.jmp == c.jmp && last.hi.Int64()+1 == c.lo.Int64() {
last.hi = c.lo
} else {
merged = append(merged, c)
}
last := stat[fallIndex]
if last.Op != OFALL {
stat = append(stat, br)
}
cc = merged
}
stat = append(stat, br)
if def != nil {
cas = append(cas, def)
}
binarySearch(len(cc), &s.done,
func(i int) *Node {
mid := cc[i-1].hi
sw.List.Set(cas)
sw.Nbody.Set(stat)
lineno = lno
le := nod(OLE, s.exprname, mid)
if s.exprname.Type.IsString() {
// Compare strings by length and then
// by value; see sort.Slice above.
lenlt := nod(OLT, nod(OLEN, s.exprname, nil), nod(OLEN, mid, nil))
leneq := nod(OEQ, nod(OLEN, s.exprname, nil), nod(OLEN, mid, nil))
le = nod(OOROR, lenlt, nod(OANDAND, leneq, le))
}
return le
},
func(i int, out *Nodes) {
c := &cc[i]
nif := nodl(c.pos, OIF, c.test(s.exprname), nil)
nif.Left = typecheck(nif.Left, ctxExpr)
nif.Left = defaultlit(nif.Left, nil)
nif.Nbody.Set1(c.jmp)
out.Append(nif)
},
)
}
// genCaseClauses generates the caseClauses value for clauses.
func (s *exprSwitch) genCaseClauses(clauses []*Node) caseClauses {
var cc caseClauses
for _, n := range clauses {
if n.Left == nil && n.List.Len() == 0 {
// default case
if cc.defjmp != nil {
Fatalf("duplicate default case not detected during typechecking")
}
cc.defjmp = n.Right
continue
}
c := caseClause{node: n, ordinal: len(cc.list)}
if n.List.Len() > 0 {
c.isconst = true
func (c *exprClause) test(exprname *Node) *Node {
// Integer range.
if c.hi != c.lo {
low := nodl(c.pos, OGE, exprname, c.lo)
high := nodl(c.pos, OLE, exprname, c.hi)
return nodl(c.pos, OANDAND, low, high)
}
switch consttype(n.Left) {
case CTFLT, CTINT, CTRUNE, CTSTR:
c.isconst = true
// Optimize "switch true { ...}" and "switch false { ... }".
if Isconst(exprname, CTBOOL) && !c.lo.Type.IsInterface() {
if exprname.Val().U.(bool) {
return c.lo
} else {
return nodl(c.pos, ONOT, c.lo, nil)
}
cc.list = append(cc.list, c)
}
if cc.defjmp == nil {
cc.defjmp = nod(OBREAK, nil, nil)
}
return cc
return nodl(c.pos, OEQ, exprname, c.lo)
}
// genCaseClauses generates the caseClauses value for clauses.
func (s *typeSwitch) genCaseClauses(clauses []*Node) caseClauses {
var cc caseClauses
for _, n := range clauses {
switch {
case n.Left == nil:
// default case
if cc.defjmp != nil {
Fatalf("duplicate default case not detected during typechecking")
}
cc.defjmp = n.Right
continue
case n.Left.Op == OLITERAL:
// nil case in type switch
if cc.niljmp != nil {
Fatalf("duplicate nil case not detected during typechecking")
}
cc.niljmp = n.Right
continue
}
func allCaseExprsAreSideEffectFree(sw *Node) bool {
// In theory, we could be more aggressive, allowing any
// side-effect-free expressions in cases, but it's a bit
// tricky because some of that information is unavailable due
// to the introduction of temporaries during order.
// Restricting to constants is simple and probably powerful
// enough.
// general case
c := caseClause{
node: n,
ordinal: len(cc.list),
isconst: !n.Left.Type.IsInterface(),
hash: typehash(n.Left.Type),
for _, ncase := range sw.List.Slice() {
if ncase.Op != OXCASE {
Fatalf("switch string(byteslice) bad op: %v", ncase.Op)
}
cc.list = append(cc.list, c)
for _, v := range ncase.List.Slice() {
if v.Op != OLITERAL {
return false
}
if cc.defjmp == nil {
cc.defjmp = nod(OBREAK, nil, nil)
}
return cc
}
return true
}
// walk generates an AST that implements sw,
// where sw is a type switch.
// The AST is generally of the form of a linear
// search using if..goto, although binary search
// is used with long runs of concrete types.
func (s *typeSwitch) walk(sw *Node) {
cond := sw.Left
sw.Left = nil
if cond == nil {
sw.List.Set(nil)
return
}
if cond.Right == nil {
yyerrorl(sw.Pos, "type switch must have an assignment")
return
}
// hasFall reports whether stmts ends with a "fallthrough" statement.
func hasFall(stmts []*Node) bool {
// Search backwards for the index of the fallthrough
// statement. Do not assume it'll be in the last
// position, since in some cases (e.g. when the statement
// list contains autotmp_ variables), one or more OVARKILL
// nodes will be at the end of the list.
cond.Right = walkexpr(cond.Right, &sw.Ninit)
if !cond.Right.Type.IsInterface() {
yyerrorl(sw.Pos, "type switch must be on an interface")
return
i := len(stmts) - 1
for i >= 0 && stmts[i].Op == OVARKILL {
i--
}
return i >= 0 && stmts[i].Op == OFALL
}
var cas []*Node
// predeclare temporary variables and the boolean var
s.facename = temp(cond.Right.Type)
a := nod(OAS, s.facename, cond.Right)
a = typecheck(a, ctxStmt)
cas = append(cas, a)
// walkTypeSwitch generates an AST that implements sw, where sw is a
// type switch.
func walkTypeSwitch(sw *Node) {
var s typeSwitch
s.facename = sw.Left.Right
sw.Left = nil
s.facename = walkexpr(s.facename, &sw.Ninit)
s.facename = copyexpr(s.facename, s.facename.Type, &sw.Nbody)
s.okname = temp(types.Types[TBOOL])
s.okname = typecheck(s.okname, ctxExpr)
s.hashname = temp(types.Types[TUINT32])
s.hashname = typecheck(s.hashname, ctxExpr)
// set up labels and jumps
casebody(sw, s.facename)
clauses := s.genCaseClauses(sw.List.Slice())
sw.List.Set(nil)
def := clauses.defjmp
// Get interface descriptor word.
// For empty interfaces this will be the type.
// For non-empty interfaces this will be the itab.
itab := nod(OITAB, s.facename, nil)
// For empty interfaces, do:
// if e._type == nil {
......@@ -688,230 +487,235 @@ func (s *typeSwitch) walk(sw *Node) {
// }
// h := e._type.hash
// Use a similar strategy for non-empty interfaces.
ifNil := nod(OIF, nil, nil)
ifNil.Left = nod(OEQ, itab, nodnil())
ifNil.Left = typecheck(ifNil.Left, ctxExpr)
ifNil.Left = defaultlit(ifNil.Left, nil)
// ifNil.Nbody assigned at end.
sw.Nbody.Append(ifNil)
// Get interface descriptor word.
// For empty interfaces this will be the type.
// For non-empty interfaces this will be the itab.
itab := nod(OITAB, s.facename, nil)
// Check for nil first.
i := nod(OIF, nil, nil)
i.Left = nod(OEQ, itab, nodnil())
if clauses.niljmp != nil {
// Do explicit nil case right here.
i.Nbody.Set1(clauses.niljmp)
// Load hash from type or itab.
dotHash := nodSym(ODOTPTR, itab, nil)
dotHash.Type = types.Types[TUINT32]
dotHash.SetTypecheck(1)
if s.facename.Type.IsEmptyInterface() {
dotHash.Xoffset = int64(2 * Widthptr) // offset of hash in runtime._type
} else {
// Jump to default case.
lbl := autolabel(".s")
i.Nbody.Set1(nodSym(OGOTO, nil, lbl))
// Wrap default case with label.
blk := nod(OBLOCK, nil, nil)
blk.List.Set2(nodSym(OLABEL, nil, lbl), def)
def = blk
dotHash.Xoffset = int64(2 * Widthptr) // offset of hash in runtime.itab
}
i.Left = typecheck(i.Left, ctxExpr)
i.Left = defaultlit(i.Left, nil)
cas = append(cas, i)
dotHash.SetBounded(true) // guaranteed not to fault
s.hashname = copyexpr(dotHash, dotHash.Type, &sw.Nbody)
// Load hash from type or itab.
h := nodSym(ODOTPTR, itab, nil)
h.Type = types.Types[TUINT32]
h.SetTypecheck(1)
if cond.Right.Type.IsEmptyInterface() {
h.Xoffset = int64(2 * Widthptr) // offset of hash in runtime._type
} else {
h.Xoffset = int64(2 * Widthptr) // offset of hash in runtime.itab
br := nod(OBREAK, nil, nil)
var defaultGoto, nilGoto *Node
var body Nodes
for _, ncase := range sw.List.Slice() {
var caseVar *Node
if ncase.Rlist.Len() != 0 {
caseVar = ncase.Rlist.First()
}
h.SetBounded(true) // guaranteed not to fault
a = nod(OAS, s.hashname, h)
a = typecheck(a, ctxStmt)
cas = append(cas, a)
cc := clauses.list
// For single-type cases, we initialize the case
// variable as part of the type assertion; but in
// other cases, we initialize it in the body.
singleType := ncase.List.Len() == 1 && ncase.List.First().Op == OTYPE
// insert type equality check into each case block
for _, c := range cc {
c.node.Right = s.typeone(c.node)
label := autolabel(".s")
jmp := npos(ncase.Pos, nodSym(OGOTO, nil, label))
if ncase.List.Len() == 0 { // default:
if defaultGoto != nil {
Fatalf("duplicate default case not detected during typechecking")
}
defaultGoto = jmp
}
// generate list of if statements, binary search for constant sequences
for len(cc) > 0 {
if !cc[0].isconst {
n := cc[0].node
cas = append(cas, n.Right)
cc = cc[1:]
for _, n1 := range ncase.List.Slice() {
if n1.isNil() { // case nil:
if nilGoto != nil {
Fatalf("duplicate nil case not detected during typechecking")
}
nilGoto = jmp
continue
}
// identify run of constants
var run int
for run = 1; run < len(cc) && cc[run].isconst; run++ {
if singleType {
s.Add(n1.Type, caseVar, jmp)
} else {
s.Add(n1.Type, nil, jmp)
}
}
// sort by hash
sort.Sort(caseClauseByType(cc[:run]))
// for debugging: linear search
if false {
for i := 0; i < run; i++ {
n := cc[i].node
cas = append(cas, n.Right)
body.Append(npos(ncase.Pos, nodSym(OLABEL, nil, label)))
if caseVar != nil && !singleType {
l := []*Node{
nodl(ncase.Pos, ODCL, caseVar, nil),
nodl(ncase.Pos, OAS, caseVar, s.facename),
}
continue
typecheckslice(l, ctxStmt)
body.Append(l...)
}
// combine adjacent cases with the same hash
var batch []caseClause
for i, j := 0, 0; i < run; i = j {
hash := []*Node{cc[i].node.Right}
for j = i + 1; j < run && cc[i].hash == cc[j].hash; j++ {
hash = append(hash, cc[j].node.Right)
body.Append(ncase.Nbody.Slice()...)
body.Append(br)
}
cc[i].node.Right = liststmt(hash)
batch = append(batch, cc[i])
sw.List.Set(nil)
if defaultGoto == nil {
defaultGoto = br
}
// binary search among cases to narrow by hash
cas = append(cas, s.walkCases(batch))
cc = cc[run:]
if nilGoto != nil {
ifNil.Nbody.Set1(nilGoto)
} else {
// TODO(mdempsky): Just use defaultGoto directly.
// Jump to default case.
label := autolabel(".s")
ifNil.Nbody.Set1(nodSym(OGOTO, nil, label))
// Wrap default case with label.
blk := nod(OBLOCK, nil, nil)
blk.List.Set2(nodSym(OLABEL, nil, label), defaultGoto)
defaultGoto = blk
}
// handle default case
if nerrors == 0 {
cas = append(cas, def)
sw.Nbody.Prepend(cas...)
sw.List.Set(nil)
s.Emit(&sw.Nbody)
sw.Nbody.Append(defaultGoto)
sw.Nbody.AppendNodes(&body)
walkstmtlist(sw.Nbody.Slice())
}
}
// typeone generates an AST that jumps to the
// case body if the variable is of type t.
func (s *typeSwitch) typeone(t *Node) *Node {
var name *Node
var init Nodes
if t.Rlist.Len() == 0 {
name = nblank
nblank = typecheck(nblank, ctxExpr|ctxAssign)
} else {
name = t.Rlist.First()
init.Append(nod(ODCL, name, nil))
a := nod(OAS, name, nil)
a = typecheck(a, ctxStmt)
init.Append(a)
}
a := nod(OAS2, nil, nil)
a.List.Set2(name, s.okname) // name, ok =
b := nod(ODOTTYPE, s.facename, nil)
b.Type = t.Left.Type // interface.(type)
a.Rlist.Set1(b)
a = typecheck(a, ctxStmt)
a = walkexpr(a, &init)
init.Append(a)
c := nod(OIF, nil, nil)
c.Left = s.okname
c.Nbody.Set1(t.Right) // if ok { goto l }
init.Append(c)
return init.asblock()
// A typeSwitch walks a type switch.
type typeSwitch struct {
// Temporary variables (i.e., ONAMEs) used by type switch dispatch logic:
facename *Node // value being type-switched on
hashname *Node // type hash of the value being type-switched on
okname *Node // boolean used for comma-ok type assertions
done Nodes
clauses []typeClause
}
// walkCases generates an AST implementing the cases in cc.
func (s *typeSwitch) walkCases(cc []caseClause) *Node {
if len(cc) < binarySearchMin {
var cas []*Node
for _, c := range cc {
n := c.node
if !c.isconst {
Fatalf("typeSwitch walkCases")
}
a := nod(OIF, nil, nil)
a.Left = nod(OEQ, s.hashname, nodintconst(int64(c.hash)))
a.Left = typecheck(a.Left, ctxExpr)
a.Left = defaultlit(a.Left, nil)
a.Nbody.Set1(n.Right)
cas = append(cas, a)
type typeClause struct {
hash uint32
body Nodes
}
func (s *typeSwitch) Add(typ *types.Type, caseVar *Node, jmp *Node) {
var body Nodes
if caseVar != nil {
l := []*Node{
nod(ODCL, caseVar, nil),
nod(OAS, caseVar, nil),
}
return liststmt(cas)
typecheckslice(l, ctxStmt)
body.Append(l...)
} else {
caseVar = nblank
}
// cv, ok = iface.(type)
as := nod(OAS2, nil, nil)
as.List.Set2(caseVar, s.okname) // cv, ok =
dot := nod(ODOTTYPE, s.facename, nil)
dot.Type = typ // iface.(type)
as.Rlist.Set1(dot)
as = typecheck(as, ctxStmt)
as = walkexpr(as, &body)
body.Append(as)
// if ok { goto label }
nif := nod(OIF, nil, nil)
nif.Left = s.okname
nif.Nbody.Set1(jmp)
body.Append(nif)
if !typ.IsInterface() {
s.clauses = append(s.clauses, typeClause{
hash: typehash(typ),
body: body,
})
return
}
// find the middle and recur
half := len(cc) / 2
a := nod(OIF, nil, nil)
a.Left = nod(OLE, s.hashname, nodintconst(int64(cc[half-1].hash)))
a.Left = typecheck(a.Left, ctxExpr)
a.Left = defaultlit(a.Left, nil)
a.Nbody.Set1(s.walkCases(cc[:half]))
a.Rlist.Set1(s.walkCases(cc[half:]))
return a
s.flush()
s.done.AppendNodes(&body)
}
// caseClauseByConstVal sorts clauses by constant value to enable binary search.
type caseClauseByConstVal []caseClause
func (x caseClauseByConstVal) Len() int { return len(x) }
func (x caseClauseByConstVal) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
func (x caseClauseByConstVal) Less(i, j int) bool {
// n1 and n2 might be individual constants or integer ranges.
// We have checked for duplicates already,
// so ranges can be safely represented by any value in the range.
n1 := x[i].node
var v1 interface{}
if s := n1.List.Slice(); s != nil {
v1 = s[0].Val().U
} else {
v1 = n1.Left.Val().U
}
func (s *typeSwitch) Emit(out *Nodes) {
s.flush()
out.AppendNodes(&s.done)
}
n2 := x[j].node
var v2 interface{}
if s := n2.List.Slice(); s != nil {
v2 = s[0].Val().U
} else {
v2 = n2.Left.Val().U
func (s *typeSwitch) flush() {
cc := s.clauses
s.clauses = nil
if len(cc) == 0 {
return
}
switch v1 := v1.(type) {
case *Mpflt:
return v1.Cmp(v2.(*Mpflt)) < 0
case *Mpint:
return v1.Cmp(v2.(*Mpint)) < 0
case string:
// Sort strings by length and then by value.
// It is much cheaper to compare lengths than values,
// and all we need here is consistency.
// We respect this sorting in exprSwitch.walkCases.
a := v1
b := v2.(string)
if len(a) != len(b) {
return len(a) < len(b)
sort.Slice(cc, func(i, j int) bool { return cc[i].hash < cc[j].hash })
// Combine adjacent cases with the same hash.
merged := cc[:1]
for _, c := range cc[1:] {
last := &merged[len(merged)-1]
if last.hash == c.hash {
last.body.AppendNodes(&c.body)
} else {
merged = append(merged, c)
}
return a < b
}
cc = merged
Fatalf("caseClauseByConstVal passed bad clauses %v < %v", x[i].node.Left, x[j].node.Left)
return false
binarySearch(len(cc), &s.done,
func(i int) *Node {
return nod(OLE, s.hashname, nodintconst(int64(cc[i-1].hash)))
},
func(i int, out *Nodes) {
// TODO(mdempsky): Omit hash equality check if
// there's only one type.
c := cc[i]
a := nod(OIF, nil, nil)
a.Left = nod(OEQ, s.hashname, nodintconst(int64(c.hash)))
a.Left = typecheck(a.Left, ctxExpr)
a.Left = defaultlit(a.Left, nil)
a.Nbody.AppendNodes(&c.body)
out.Append(a)
},
)
}
type caseClauseByType []caseClause
func (x caseClauseByType) Len() int { return len(x) }
func (x caseClauseByType) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
func (x caseClauseByType) Less(i, j int) bool {
c1, c2 := x[i], x[j]
// sort by hash code, then ordinal (for the rare case of hash collisions)
if c1.hash != c2.hash {
return c1.hash < c2.hash
// binarySearch constructs a binary search tree for handling n cases,
// and appends it to out. It's used for efficiently implementing
// switch statements.
//
// less(i) should return a boolean expression. If it evaluates true,
// then cases [0, i) will be tested; otherwise, cases [i, n).
//
// base(i, out) should append statements to out to test the i'th case.
func binarySearch(n int, out *Nodes, less func(i int) *Node, base func(i int, out *Nodes)) {
const binarySearchMin = 4 // minimum number of cases for binary search
var do func(lo, hi int, out *Nodes)
do = func(lo, hi int, out *Nodes) {
n := hi - lo
if n < binarySearchMin {
for i := lo; i < hi; i++ {
base(i, out)
}
return
}
return c1.ordinal < c2.ordinal
}
type constIntNodesByVal []*Node
half := lo + n/2
nif := nod(OIF, nil, nil)
nif.Left = less(half)
nif.Left = typecheck(nif.Left, ctxExpr)
nif.Left = defaultlit(nif.Left, nil)
do(lo, half, &nif.Nbody)
do(half, hi, &nif.Rlist)
out.Append(nif)
}
func (x constIntNodesByVal) Len() int { return len(x) }
func (x constIntNodesByVal) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
func (x constIntNodesByVal) Less(i, j int) bool {
return x[i].Val().U.(*Mpint).Cmp(x[j].Val().U.(*Mpint)) < 0
do(0, n, out)
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gc
import (
"testing"
)
func nodrune(r rune) *Node {
v := new(Mpint)
v.SetInt64(int64(r))
v.Rune = true
return nodlit(Val{v})
}
func nodflt(f float64) *Node {
v := newMpflt()
v.SetFloat64(f)
return nodlit(Val{v})
}
func TestCaseClauseByConstVal(t *testing.T) {
tests := []struct {
a, b *Node
}{
// CTFLT
{nodflt(0.1), nodflt(0.2)},
// CTINT
{nodintconst(0), nodintconst(1)},
// CTRUNE
{nodrune('a'), nodrune('b')},
// CTSTR
{nodlit(Val{"ab"}), nodlit(Val{"abc"})},
{nodlit(Val{"ab"}), nodlit(Val{"xyz"})},
{nodlit(Val{"abc"}), nodlit(Val{"xyz"})},
}
for i, test := range tests {
a := caseClause{node: nod(OXXX, test.a, nil)}
b := caseClause{node: nod(OXXX, test.b, nil)}
s := caseClauseByConstVal{a, b}
if less := s.Less(0, 1); !less {
t.Errorf("%d: caseClauseByConstVal(%v, %v) = false", i, test.a, test.b)
}
if less := s.Less(1, 0); less {
t.Errorf("%d: caseClauseByConstVal(%v, %v) = true", i, test.a, test.b)
}
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment