Commit 1a53915c authored by Daniel Martí

cmd/compile: initial rulegen rewrite

rulegen.go produces plaintext Go code directly, which was fine for a
while. However, that has started to become a bottleneck as code
generation grows more complex, since we can only emit code one line at
a time.

Some workarounds were used, like multiple layers of buffers to generate
chunks of code, which were then searched with strings.Contains to see
whether variables needed to be defined. However, that approach is
error-prone, verbose, and difficult to work with.

A better approach is to generate an intermediate syntax tree in memory,
which we can inspect and modify easily. For example, we could run a
number of "passes" on the syntax tree before writing to disk, such as
removing unused variables, simplifying logic, or moving declarations
closer to their uses.
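
As a rough illustration of the kind of pass this enables, here is a
minimal, self-contained sketch (with hypothetical expressions and
helper names, not the actual rulegen types): identifier uses are
counted over parsed expressions with ast.Inspect, and a helper
declaration such as "b := v.Block" would only be emitted when
something actually refers to it:

    package main

    import (
        "fmt"
        "go/ast"
        "go/parser"
        "log"
    )

    // mustExpr parses a Go expression, exiting on error.
    func mustExpr(src string) ast.Expr {
        e, err := parser.ParseExpr(src)
        if err != nil {
            log.Fatalf("parse error on %q: %v", src, err)
        }
        return e
    }

    func main() {
        // Hypothetical statement values kept in memory rather than in a
        // text buffer.
        values := []ast.Expr{
            mustExpr("v.Args[0]"),
            mustExpr("b.NewValue0(v.Pos, OpConst64, typ.Int64)"),
        }

        // One pass: count identifier uses across the whole body.
        uses := make(map[string]int)
        for _, e := range values {
            ast.Inspect(e, func(n ast.Node) bool {
                if id, ok := n.(*ast.Ident); ok {
                    uses[id.Name]++
                }
                return true
            })
        }

        // A later step can then declare helpers only when needed.
        if uses["b"] > 0 {
            fmt.Println("b := v.Block")
        }
        if uses["typ"] > 0 {
            fmt.Println("typ := &b.Func.Config.Types")
        }
    }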

This is the first step in that direction, without changing any of the
generated code. We didn't use go/ast directly, as it's too complex for
our needs. In particular, we only need a few kinds of simple statements,
but we do want to support arbitrary expressions. As such, define a
simple set of statement structs, and add thin layers for printer.Fprint
and ast.Inspect.
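
To illustrate that shape, here is a simplified sketch (hypothetical
names, not the exact structs added in the diff below): a statement
struct holds an ast.Expr for its value, and a small fprint helper
hands any expression off to go/printer:

    package main

    import (
        "go/ast"
        "go/parser"
        "go/printer"
        "go/token"
        "io"
        "log"
        "os"
    )

    // declare models a "name := value" statement; its value stays an
    // ordinary go/ast expression rather than a custom node.
    type declare struct {
        name  string
        value ast.Expr
    }

    var emptyFset = token.NewFileSet()

    // fprint is the thin printing layer: custom statements are printed
    // by hand, while embedded expressions are delegated to go/printer.
    func fprint(w io.Writer, n interface{}) {
        switch n := n.(type) {
        case *declare:
            io.WriteString(w, n.name+" := ")
            fprint(w, n.value)
            io.WriteString(w, "\n")
        case ast.Node:
            printer.Fprint(w, emptyFset, n)
        default:
            log.Fatalf("cannot print %T", n)
        }
    }

    func main() {
        value, err := parser.ParseExpr("b.NewValue0(v.Pos, OpCopy, t)")
        if err != nil {
            log.Fatal(err)
        }
        fprint(os.Stdout, &declare{name: "v0", value: value})
        // Prints: v0 := b.NewValue0(v.Pos, OpCopy, t)
    }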

A nice side effect of this change, besides removing some buffers and
string handling, is that we can now avoid passing so many parameters
around. And, while we add over a hundred lines of code, the tricky
pieces of code are now a bit simpler to follow.

While at it, apply some cleanups, such as replacing isVariable with
token.IsIdentifier, and consistently using log.Fatalf.

Follow-up CLs will start improving the generated code, as well as
simplifying the rulegen code itself. I've added some TODOs for the
low-hanging fruit that I intend to work on right after.

Updates #30810.

Change-Id: Ic371c192b29c85dfc4a001be7fbcbeec85facc9d
Reviewed-on: https://go-review.googlesource.com/c/go/+/177539
Run-TryBot: Daniel Martí <mvdan@mvdan.cc>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Keith Randall <khr@golang.org>
parent cd33d271
......@@ -16,10 +16,15 @@ import (
"bytes"
"flag"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/printer"
"go/token"
"io"
"io/ioutil"
"log"
"math"
"os"
"regexp"
"sort"
......@@ -47,9 +52,7 @@ import (
// If multiple rules match, the first one in file order is selected.
var (
genLog = flag.Bool("log", false, "generate code that logs; for debugging only")
)
var genLog = flag.Bool("log", false, "generate code that logs; for debugging only")
type Rule struct {
rule string
......@@ -136,7 +139,8 @@ func genRulesSuffix(arch arch, suff string) {
r := Rule{rule: rule3, loc: loc}
if rawop := strings.Split(rule3, " ")[0][1:]; isBlock(rawop, arch) {
blockrules[rawop] = append(blockrules[rawop], r)
} else {
continue
}
// Do fancier value op matching.
match, _, _ := r.parse()
op, oparch, _, _, _, _ := parseValue(match, arch, loc)
......@@ -144,7 +148,6 @@ func genRulesSuffix(arch arch, suff string) {
oprules[opname] = append(oprules[opname], r)
}
}
}
rule = ""
ruleLineno = 0
}
......@@ -162,159 +165,378 @@ func genRulesSuffix(arch arch, suff string) {
}
sort.Strings(ops)
// Start output buffer, write header.
w := new(bytes.Buffer)
fmt.Fprintf(w, "// Code generated from gen/%s%s.rules; DO NOT EDIT.\n", arch.name, suff)
fmt.Fprintln(w, "// generated with: cd gen; go run *.go")
fmt.Fprintln(w)
fmt.Fprintln(w, "package ssa")
fmt.Fprintln(w, "import \"fmt\"")
fmt.Fprintln(w, "import \"math\"")
fmt.Fprintln(w, "import \"cmd/internal/obj\"")
fmt.Fprintln(w, "import \"cmd/internal/objabi\"")
fmt.Fprintln(w, "import \"cmd/compile/internal/types\"")
fmt.Fprintln(w, "var _ = fmt.Println // in case not otherwise used")
fmt.Fprintln(w, "var _ = math.MinInt8 // in case not otherwise used")
fmt.Fprintln(w, "var _ = obj.ANOP // in case not otherwise used")
fmt.Fprintln(w, "var _ = objabi.GOROOT // in case not otherwise used")
fmt.Fprintln(w, "var _ = types.TypeMem // in case not otherwise used")
fmt.Fprintln(w)
file := &File{arch: arch, suffix: suff}
const chunkSize = 10
// Main rewrite routine is a switch on v.Op.
fmt.Fprintf(w, "func rewriteValue%s%s(v *Value) bool {\n", arch.name, suff)
fmt.Fprintf(w, "switch v.Op {\n")
fn := &Func{kind: "Value"}
sw := &Switch{expr: exprf("v.Op")}
for _, op := range ops {
fmt.Fprintf(w, "case %s:\n", op)
fmt.Fprint(w, "return ")
var ors []string
for chunk := 0; chunk < len(oprules[op]); chunk += chunkSize {
if chunk > 0 {
fmt.Fprint(w, " || ")
ors = append(ors, fmt.Sprintf("rewriteValue%s%s_%s_%d(v)", arch.name, suff, op, chunk))
}
fmt.Fprintf(w, "rewriteValue%s%s_%s_%d(v)", arch.name, suff, op, chunk)
}
fmt.Fprintln(w)
swc := &Case{expr: exprf(op)}
swc.add(stmtf("return %s", strings.Join(ors, " || ")))
sw.add(swc)
}
fmt.Fprintf(w, "}\n")
fmt.Fprintf(w, "return false\n")
fmt.Fprintf(w, "}\n")
fn.add(sw)
fn.add(stmtf("return false"))
file.add(fn)
// Generate a routine per op. Note that we don't make one giant routine
// because it is too big for some compilers.
for _, op := range ops {
for chunk := 0; chunk < len(oprules[op]); chunk += chunkSize {
buf := new(bytes.Buffer)
var canFail bool
rules := oprules[op]
// rr is kept between chunks, so that a following chunk checks
// that the previous one ended with a rule that wasn't
// unconditional.
var rr *RuleRewrite
for chunk := 0; chunk < len(rules); chunk += chunkSize {
endchunk := chunk + chunkSize
if endchunk > len(oprules[op]) {
endchunk = len(oprules[op])
}
for i, rule := range oprules[op][chunk:endchunk] {
match, cond, result := rule.parse()
fmt.Fprintf(buf, "// match: %s\n", match)
fmt.Fprintf(buf, "// cond: %s\n", cond)
fmt.Fprintf(buf, "// result: %s\n", result)
canFail = false
fmt.Fprintf(buf, "for {\n")
pos, _, matchCanFail := genMatch(buf, arch, match, rule.loc)
if pos == "" {
pos = "v.Pos"
if endchunk > len(rules) {
endchunk = len(rules)
}
if matchCanFail {
canFail = true
fn := &Func{
kind: "Value",
suffix: fmt.Sprintf("_%s_%d", op, chunk),
}
if cond != "" {
fmt.Fprintf(buf, "if !(%s) {\nbreak\n}\n", cond)
canFail = true
var rewrites bodyBase
for _, rule := range rules[chunk:endchunk] {
if rr != nil && !rr.canFail {
log.Fatalf("unconditional rule %s is followed by other rules", rr.match)
}
rr = &RuleRewrite{loc: rule.loc}
rr.match, rr.cond, rr.result = rule.parse()
pos, _ := genMatch(rr, arch, rr.match)
if pos == "" {
pos = "v.Pos"
}
if !canFail && i+chunk != len(oprules[op])-1 {
log.Fatalf("unconditional rule %s is followed by other rules", match)
if rr.cond != "" {
rr.add(breakf("!(%s)", rr.cond))
}
genResult(buf, arch, result, rule.loc, pos)
genResult(rr, arch, rr.result, pos)
if *genLog {
fmt.Fprintf(buf, "logRule(\"%s\")\n", rule.loc)
rr.add(stmtf("logRule(%q)", rule.loc))
}
rewrites.add(rr)
}
fmt.Fprintf(buf, "return true\n")
fmt.Fprintf(buf, "}\n")
// TODO(mvdan): remove unused vars later instead
uses := make(map[string]int)
walk(&rewrites, func(node Node) {
switch node := node.(type) {
case *Declare:
// work around var shadowing
// TODO(mvdan): forbid it instead.
uses[node.name] = math.MinInt32
case *ast.Ident:
uses[node.Name]++
}
if canFail {
fmt.Fprintf(buf, "return false\n")
})
if uses["b"]+uses["config"]+uses["fe"]+uses["typ"] > 0 {
fn.add(declf("b", "v.Block"))
}
body := buf.String()
// Figure out whether we need b, config, fe, and/or types; provide them if so.
hasb := strings.Contains(body, " b.")
hasconfig := strings.Contains(body, "config.") || strings.Contains(body, "config)")
hasfe := strings.Contains(body, "fe.")
hastyps := strings.Contains(body, "typ.")
fmt.Fprintf(w, "func rewriteValue%s%s_%s_%d(v *Value) bool {\n", arch.name, suff, op, chunk)
if hasb || hasconfig || hasfe || hastyps {
fmt.Fprintln(w, "b := v.Block")
if uses["config"] > 0 {
fn.add(declf("config", "b.Func.Config"))
}
if hasconfig {
fmt.Fprintln(w, "config := b.Func.Config")
if uses["fe"] > 0 {
fn.add(declf("fe", "b.Func.fe"))
}
if hasfe {
fmt.Fprintln(w, "fe := b.Func.fe")
if uses["typ"] > 0 {
fn.add(declf("typ", "&b.Func.Config.Types"))
}
if hastyps {
fmt.Fprintln(w, "typ := &b.Func.Config.Types")
fn.add(rewrites.list...)
if rr.canFail {
fn.add(stmtf("return false"))
}
fmt.Fprint(w, body)
fmt.Fprintf(w, "}\n")
file.add(fn)
}
}
// Generate block rewrite function. There are only a few block types
// so we can make this one function with a switch.
fmt.Fprintf(w, "func rewriteBlock%s%s(b *Block) bool {\n", arch.name, suff)
fmt.Fprintln(w, "config := b.Func.Config")
fmt.Fprintln(w, "typ := &config.Types")
fmt.Fprintln(w, "_ = typ")
fmt.Fprintln(w, "v := b.Control")
fmt.Fprintln(w, "_ = v")
fmt.Fprintf(w, "switch b.Kind {\n")
ops = nil
fn = &Func{kind: "Block"}
fn.add(declf("config", "b.Func.Config"))
// TODO(mvdan): declare these only if needed
fn.add(declf("typ", "&config.Types"))
fn.add(stmtf("_ = typ"))
fn.add(declf("v", "b.Control"))
fn.add(stmtf("_ = v"))
sw = &Switch{expr: exprf("b.Kind")}
ops = ops[:0]
for op := range blockrules {
ops = append(ops, op)
}
sort.Strings(ops)
for _, op := range ops {
fmt.Fprintf(w, "case %s:\n", blockName(op, arch))
swc := &Case{expr: exprf("%s", blockName(op, arch))}
for _, rule := range blockrules[op] {
match, cond, result := rule.parse()
fmt.Fprintf(w, "// match: %s\n", match)
fmt.Fprintf(w, "// cond: %s\n", cond)
fmt.Fprintf(w, "// result: %s\n", result)
swc.add(genBlockRewrite(rule, arch))
}
sw.add(swc)
}
fn.add(sw)
fn.add(stmtf("return false"))
file.add(fn)
// gofmt result
buf := new(bytes.Buffer)
fprint(buf, file)
b := buf.Bytes()
src, err := format.Source(b)
if err != nil {
fmt.Printf("%s\n", b)
panic(err)
}
// Write to file
if err := ioutil.WriteFile("../rewrite"+arch.name+suff+".go", src, 0666); err != nil {
log.Fatalf("can't write output: %v\n", err)
}
}
_, _, _, aux, s := extract(match) // remove parens, then split
func walk(node Node, fn func(Node)) {
fn(node)
switch node := node.(type) {
case *bodyBase:
case *File:
case *Func:
case *Switch:
walk(node.expr, fn)
case *Case:
walk(node.expr, fn)
case *RuleRewrite:
case *Declare:
walk(node.value, fn)
case *CondBreak:
walk(node.expr, fn)
case ast.Node:
ast.Inspect(node, func(node ast.Node) bool {
fn(node)
return true
})
default:
log.Fatalf("cannot walk %T", node)
}
if wb, ok := node.(interface{ body() []Statement }); ok {
for _, node := range wb.body() {
walk(node, fn)
}
}
fn(nil)
}
loopw := new(bytes.Buffer)
func fprint(w io.Writer, n Node) {
switch n := n.(type) {
case *File:
fmt.Fprintf(w, "// Code generated from gen/%s%s.rules; DO NOT EDIT.\n", n.arch.name, n.suffix)
fmt.Fprintf(w, "// generated with: cd gen; go run *.go\n")
fmt.Fprintf(w, "\npackage ssa\n")
// TODO(mvdan): keep the needed imports only
fmt.Fprintln(w, "import \"fmt\"")
fmt.Fprintln(w, "import \"math\"")
fmt.Fprintln(w, "import \"cmd/internal/obj\"")
fmt.Fprintln(w, "import \"cmd/internal/objabi\"")
fmt.Fprintln(w, "import \"cmd/compile/internal/types\"")
fmt.Fprintln(w, "var _ = fmt.Println // in case not otherwise used")
fmt.Fprintln(w, "var _ = math.MinInt8 // in case not otherwise used")
fmt.Fprintln(w, "var _ = obj.ANOP // in case not otherwise used")
fmt.Fprintln(w, "var _ = objabi.GOROOT // in case not otherwise used")
fmt.Fprintln(w, "var _ = types.TypeMem // in case not otherwise used")
fmt.Fprintln(w)
for _, f := range n.list {
f := f.(*Func)
fmt.Fprintf(w, "func rewrite%s%s%s%s(", f.kind, n.arch.name, n.suffix, f.suffix)
fmt.Fprintf(w, "%c *%s) bool {\n", strings.ToLower(f.kind)[0], f.kind)
for _, n := range f.list {
fprint(w, n)
}
fmt.Fprintf(w, "}\n")
}
case *Switch:
fmt.Fprintf(w, "switch ")
fprint(w, n.expr)
fmt.Fprintf(w, " {\n")
for _, n := range n.list {
fprint(w, n)
}
fmt.Fprintf(w, "}\n")
case *Case:
fmt.Fprintf(w, "case ")
fprint(w, n.expr)
fmt.Fprintf(w, ":\n")
for _, n := range n.list {
fprint(w, n)
}
case *RuleRewrite:
fmt.Fprintf(w, "// match: %s\n", n.match)
fmt.Fprintf(w, "// cond: %s\n", n.cond)
fmt.Fprintf(w, "// result: %s\n", n.result)
if n.checkOp != "" {
fmt.Fprintf(w, "for v.Op == %s {\n", n.checkOp)
} else {
fmt.Fprintf(w, "for {\n")
}
for _, n := range n.list {
fprint(w, n)
}
fmt.Fprintf(w, "return true\n}\n")
case *Declare:
fmt.Fprintf(w, "%s := ", n.name)
fprint(w, n.value)
fmt.Fprintln(w)
case *CondBreak:
fmt.Fprintf(w, "if ")
fprint(w, n.expr)
fmt.Fprintf(w, " {\nbreak\n}\n")
case ast.Node:
printer.Fprint(w, emptyFset, n)
if _, ok := n.(ast.Stmt); ok {
fmt.Fprintln(w)
}
default:
log.Fatalf("cannot print %T", n)
}
}
var emptyFset = token.NewFileSet()
// Node can be a Statement or an ast.Expr.
type Node interface{}
// Statement can be one of our high-level statement struct types, or an
// ast.Stmt under some limited circumstances.
type Statement interface{}
// bodyBase is shared by all of our statement pseudo-node types which can
// contain other statements.
type bodyBase struct {
list []Statement
canFail bool
}
func (w *bodyBase) body() []Statement { return w.list }
func (w *bodyBase) add(nodes ...Statement) {
w.list = append(w.list, nodes...)
for _, node := range nodes {
if _, ok := node.(*CondBreak); ok {
w.canFail = true
}
}
}
// declared reports if the body contains a Declare with the given name.
func (w *bodyBase) declared(name string) bool {
for _, s := range w.list {
if decl, ok := s.(*Declare); ok && decl.name == name {
return true
}
}
return false
}
// These types define some high-level statement struct types, which can be used
// as a Statement. This allows us to keep some node structs simpler, and have
// higher-level nodes such as an entire rule rewrite.
//
// Note that ast.Expr is always used as-is; we don't declare our own expression
// nodes.
type (
File struct {
bodyBase // []*Func
arch arch
suffix string
}
Func struct {
bodyBase
kind string // "Value" or "Block"
suffix string
}
Switch struct {
bodyBase // []*Case
expr ast.Expr
}
Case struct {
bodyBase
expr ast.Expr
}
RuleRewrite struct {
bodyBase
match, cond, result string // top comments
checkOp string
alloc int // for unique var names
loc string // file name & line number of the original rule
}
Declare struct {
name string
value ast.Expr
}
CondBreak struct {
expr ast.Expr
}
)
// exprf parses a Go expression generated from fmt.Sprintf, exiting the
// program if an error occurs.
func exprf(format string, a ...interface{}) ast.Expr {
src := fmt.Sprintf(format, a...)
expr, err := parser.ParseExpr(src)
if err != nil {
log.Fatalf("expr parse error on %q: %v", src, err)
}
return expr
}
// stmtf parses a Go statement generated from fmt.Sprintf. This function is only
// meant for simple statements that don't have a custom Statement node declared
// in this package, such as ast.ReturnStmt or ast.ExprStmt.
func stmtf(format string, a ...interface{}) Statement {
src := fmt.Sprintf(format, a...)
fsrc := "package p\nfunc _() {\n" + src + "\n}\n"
file, err := parser.ParseFile(token.NewFileSet(), "", fsrc, 0)
if err != nil {
log.Fatalf("stmt parse error on %q: %v", src, err)
}
return file.Decls[0].(*ast.FuncDecl).Body.List[0]
}
// declf constructs a simple "name := value" declaration, using exprf for its
// value.
func declf(name, format string, a ...interface{}) *Declare {
return &Declare{name, exprf(format, a...)}
}
// breakf constructs a simple "if cond { break }" statement, using exprf for its
// condition.
func breakf(format string, a ...interface{}) *CondBreak {
return &CondBreak{exprf(format, a...)}
}
func genBlockRewrite(rule Rule, arch arch) *RuleRewrite {
rr := &RuleRewrite{loc: rule.loc}
rr.match, rr.cond, rr.result = rule.parse()
_, _, _, aux, s := extract(rr.match) // remove parens, then split
// check match of control value
pos := ""
checkOp := ""
if s[0] != "nil" {
if strings.Contains(s[0], "(") {
pos, checkOp, _ = genMatch0(loopw, arch, s[0], "v", map[string]struct{}{}, rule.loc)
pos, rr.checkOp = genMatch0(rr, arch, s[0], "v")
} else {
fmt.Fprintf(loopw, "%s := b.Control\n", s[0])
rr.add(declf(s[0], "b.Control"))
}
}
if aux != "" {
fmt.Fprintf(loopw, "%s := b.Aux\n", aux)
rr.add(declf(aux, "b.Aux"))
}
if cond != "" {
fmt.Fprintf(loopw, "if !(%s) {\nbreak\n}\n", cond)
if rr.cond != "" {
rr.add(breakf("!(%s)", rr.cond))
}
// Rule matches. Generate result.
outop, _, _, aux, t := extract(result) // remove parens, then split
outop, _, _, aux, t := extract(rr.result) // remove parens, then split
newsuccs := t[1:]
// Check if newsuccs is the same set as succs.
......@@ -336,19 +558,20 @@ func genRulesSuffix(arch arch, suff string) {
log.Fatalf("unmatched successors %v in %s", m, rule)
}
fmt.Fprintf(loopw, "b.Kind = %s\n", blockName(outop, arch))
rr.add(stmtf("b.Kind = %s", blockName(outop, arch)))
if t[0] == "nil" {
fmt.Fprintf(loopw, "b.SetControl(nil)\n")
rr.add(stmtf("b.SetControl(nil)"))
} else {
if pos == "" {
pos = "v.Pos"
}
fmt.Fprintf(loopw, "b.SetControl(%s)\n", genResult0(loopw, arch, t[0], new(int), false, false, rule.loc, pos))
v := genResult0(rr, arch, t[0], false, false, pos)
rr.add(stmtf("b.SetControl(%s)", v))
}
if aux != "" {
fmt.Fprintf(loopw, "b.Aux = %s\n", aux)
rr.add(stmtf("b.Aux = %s", aux))
} else {
fmt.Fprintln(loopw, "b.Aux = nil")
rr.add(stmtf("b.Aux = nil"))
}
succChanged := false
......@@ -364,54 +587,26 @@ func genRulesSuffix(arch arch, suff string) {
if succs[0] != newsuccs[1] || succs[1] != newsuccs[0] {
log.Fatalf("can only handle swapped successors in %s", rule)
}
fmt.Fprintln(loopw, "b.swapSuccessors()")
rr.add(stmtf("b.swapSuccessors()"))
}
if *genLog {
fmt.Fprintf(loopw, "logRule(\"%s\")\n", rule.loc)
}
fmt.Fprintf(loopw, "return true\n")
if checkOp != "" {
fmt.Fprintf(w, "for v.Op == %s {\n", checkOp)
} else {
fmt.Fprintf(w, "for {\n")
}
io.Copy(w, loopw)
fmt.Fprintf(w, "}\n")
}
}
fmt.Fprintf(w, "}\n")
fmt.Fprintf(w, "return false\n")
fmt.Fprintf(w, "}\n")
// gofmt result
b := w.Bytes()
src, err := format.Source(b)
if err != nil {
fmt.Printf("%s\n", b)
panic(err)
}
// Write to file
err = ioutil.WriteFile("../rewrite"+arch.name+suff+".go", src, 0666)
if err != nil {
log.Fatalf("can't write output: %v\n", err)
rr.add(stmtf("logRule(%q)", rule.loc))
}
return rr
}
// genMatch returns the variable whose source position should be used for the
// result (or "" if no opinion), and a boolean that reports whether the match can fail.
func genMatch(w io.Writer, arch arch, match string, loc string) (pos, checkOp string, canFail bool) {
return genMatch0(w, arch, match, "v", map[string]struct{}{}, loc)
func genMatch(rr *RuleRewrite, arch arch, match string) (pos, checkOp string) {
return genMatch0(rr, arch, match, "v")
}
func genMatch0(w io.Writer, arch arch, match, v string, m map[string]struct{}, loc string) (pos, checkOp string, canFail bool) {
func genMatch0(rr *RuleRewrite, arch arch, match, v string) (pos, checkOp string) {
if match[0] != '(' || match[len(match)-1] != ')' {
panic("non-compound expr in genMatch0: " + match)
log.Fatalf("non-compound expr in genMatch0: %q", match)
}
op, oparch, typ, auxint, aux, args := parseValue(match, arch, loc)
op, oparch, typ, auxint, aux, args := parseValue(match, arch, rr.loc)
checkOp = fmt.Sprintf("Op%s%s", oparch, op.name)
......@@ -421,68 +616,40 @@ func genMatch0(w io.Writer, arch arch, match, v string, m map[string]struct{}, l
}
if typ != "" {
if !isVariable(typ) {
// code. We must match the results of this code.
fmt.Fprintf(w, "if %s.Type != %s {\nbreak\n}\n", v, typ)
canFail = true
if !token.IsIdentifier(typ) || rr.declared(typ) {
// code or variable
rr.add(breakf("%s.Type != %s", v, typ))
} else {
// variable
if _, ok := m[typ]; ok {
// must match previous variable
fmt.Fprintf(w, "if %s.Type != %s {\nbreak\n}\n", v, typ)
canFail = true
} else {
m[typ] = struct{}{}
fmt.Fprintf(w, "%s := %s.Type\n", typ, v)
}
rr.add(declf(typ, "%s.Type", v))
}
}
if auxint != "" {
if !isVariable(auxint) {
// code
fmt.Fprintf(w, "if %s.AuxInt != %s {\nbreak\n}\n", v, auxint)
canFail = true
} else {
// variable
if _, ok := m[auxint]; ok {
fmt.Fprintf(w, "if %s.AuxInt != %s {\nbreak\n}\n", v, auxint)
canFail = true
if !token.IsIdentifier(auxint) || rr.declared(auxint) {
// code or variable
rr.add(breakf("%s.AuxInt != %s", v, auxint))
} else {
m[auxint] = struct{}{}
fmt.Fprintf(w, "%s := %s.AuxInt\n", auxint, v)
}
rr.add(declf(auxint, "%s.AuxInt", v))
}
}
if aux != "" {
if !isVariable(aux) {
// code
fmt.Fprintf(w, "if %s.Aux != %s {\nbreak\n}\n", v, aux)
canFail = true
} else {
// variable
if _, ok := m[aux]; ok {
fmt.Fprintf(w, "if %s.Aux != %s {\nbreak\n}\n", v, aux)
canFail = true
if !token.IsIdentifier(aux) || rr.declared(aux) {
// code or variable
rr.add(breakf("%s.Aux != %s", v, aux))
} else {
m[aux] = struct{}{}
fmt.Fprintf(w, "%s := %s.Aux\n", aux, v)
}
rr.add(declf(aux, "%s.Aux", v))
}
}
// Access last argument first to minimize bounds checks.
if n := len(args); n > 1 {
a := args[n-1]
if _, set := m[a]; !set && a != "_" && isVariable(a) {
m[a] = struct{}{}
fmt.Fprintf(w, "%s := %s.Args[%d]\n", a, v, n-1)
if a != "_" && !rr.declared(a) && token.IsIdentifier(a) {
rr.add(declf(a, "%s.Args[%d]", v, n-1))
// delete the last argument so it is not reprocessed
args = args[:n-1]
} else {
fmt.Fprintf(w, "_ = %s.Args[%d]\n", v, n-1)
rr.add(stmtf("_ = %s.Args[%d]", v, n-1))
}
}
for i, arg := range args {
......@@ -491,41 +658,36 @@ func genMatch0(w io.Writer, arch arch, match, v string, m map[string]struct{}, l
}
if !strings.Contains(arg, "(") {
// leaf variable
if _, ok := m[arg]; ok {
if rr.declared(arg) {
// variable already has a definition. Check whether
// the old definition and the new definition match.
// For example, (add x x). Equality is just pointer equality
// on Values (so cse is important to do before lowering).
fmt.Fprintf(w, "if %s != %s.Args[%d] {\nbreak\n}\n", arg, v, i)
canFail = true
rr.add(breakf("%s != %s.Args[%d]", arg, v, i))
} else {
// remember that this variable references the given value
m[arg] = struct{}{}
fmt.Fprintf(w, "%s := %s.Args[%d]\n", arg, v, i)
rr.add(declf(arg, "%s.Args[%d]", v, i))
}
continue
}
// compound sexpr
var argname string
argname := fmt.Sprintf("%s_%d", v, i)
colon := strings.Index(arg, ":")
openparen := strings.Index(arg, "(")
if colon >= 0 && openparen >= 0 && colon < openparen {
// rule-specified name
argname = arg[:colon]
arg = arg[colon+1:]
} else {
// autogenerated name
argname = fmt.Sprintf("%s_%d", v, i)
}
if argname == "b" {
log.Fatalf("don't name args 'b', it is ambiguous with blocks")
}
fmt.Fprintf(w, "%s := %s.Args[%d]\n", argname, v, i)
w2 := new(bytes.Buffer)
argPos, argCheckOp, _ := genMatch0(w2, arch, arg, argname, m, loc)
fmt.Fprintf(w, "if %s.Op != %s {\nbreak\n}\n", argname, argCheckOp)
io.Copy(w, w2)
rr.add(declf(argname, "%s.Args[%d]", v, i))
bexpr := exprf("%s.Op != addLater", argname)
rr.add(&CondBreak{expr: bexpr})
rr.canFail = true // since we're not using breakf
argPos, argCheckOp := genMatch0(rr, arch, arg, argname)
bexpr.(*ast.BinaryExpr).Y.(*ast.Ident).Name = argCheckOp
if argPos != "" {
// Keep the argument in preference to the parent, as the
......@@ -535,28 +697,26 @@ func genMatch0(w io.Writer, arch arch, match, v string, m map[string]struct{}, l
// in the program flow.
pos = argPos
}
canFail = true
}
if op.argLength == -1 {
fmt.Fprintf(w, "if len(%s.Args) != %d {\nbreak\n}\n", v, len(args))
canFail = true
rr.add(breakf("len(%s.Args) != %d", v, len(args)))
}
return pos, checkOp, canFail
return pos, checkOp
}
func genResult(w io.Writer, arch arch, result string, loc string, pos string) {
move := false
if result[0] == '@' {
func genResult(rr *RuleRewrite, arch arch, result, pos string) {
move := result[0] == '@'
if move {
// parse @block directive
s := strings.SplitN(result[1:], " ", 2)
fmt.Fprintf(w, "b = %s\n", s[0])
rr.add(stmtf("b = %s", s[0]))
result = s[1]
move = true
}
genResult0(w, arch, result, new(int), true, move, loc, pos)
genResult0(rr, arch, result, true, move, pos)
}
func genResult0(w io.Writer, arch arch, result string, alloc *int, top, move bool, loc string, pos string) string {
func genResult0(rr *RuleRewrite, arch arch, result string, top, move bool, pos string) string {
// TODO: when generating a constant result, use f.constVal to avoid
// introducing copies just to clean them up again.
if result[0] != '(' {
......@@ -565,14 +725,14 @@ func genResult0(w io.Writer, arch arch, result string, alloc *int, top, move boo
// It is not safe in general to move a variable between blocks
// (and particularly not a phi node).
// Introduce a copy.
fmt.Fprintf(w, "v.reset(OpCopy)\n")
fmt.Fprintf(w, "v.Type = %s.Type\n", result)
fmt.Fprintf(w, "v.AddArg(%s)\n", result)
rr.add(stmtf("v.reset(OpCopy)"))
rr.add(stmtf("v.Type = %s.Type", result))
rr.add(stmtf("v.AddArg(%s)", result))
}
return result
}
op, oparch, typ, auxint, aux, args := parseValue(result, arch, loc)
op, oparch, typ, auxint, aux, args := parseValue(result, arch, rr.loc)
// Find the type of the variable.
typeOverride := typ != ""
......@@ -580,36 +740,35 @@ func genResult0(w io.Writer, arch arch, result string, alloc *int, top, move boo
typ = typeName(op.typ)
}
var v string
v := "v"
if top && !move {
v = "v"
fmt.Fprintf(w, "v.reset(Op%s%s)\n", oparch, op.name)
rr.add(stmtf("v.reset(Op%s%s)", oparch, op.name))
if typeOverride {
fmt.Fprintf(w, "v.Type = %s\n", typ)
rr.add(stmtf("v.Type = %s", typ))
}
} else {
if typ == "" {
log.Fatalf("sub-expression %s (op=Op%s%s) at %s must have a type", result, oparch, op.name, loc)
log.Fatalf("sub-expression %s (op=Op%s%s) at %s must have a type", result, oparch, op.name, rr.loc)
}
v = fmt.Sprintf("v%d", *alloc)
*alloc++
fmt.Fprintf(w, "%s := b.NewValue0(%s, Op%s%s, %s)\n", v, pos, oparch, op.name, typ)
v = fmt.Sprintf("v%d", rr.alloc)
rr.alloc++
rr.add(declf(v, "b.NewValue0(%s, Op%s%s, %s)", pos, oparch, op.name, typ))
if move && top {
// Rewrite original into a copy
fmt.Fprintf(w, "v.reset(OpCopy)\n")
fmt.Fprintf(w, "v.AddArg(%s)\n", v)
rr.add(stmtf("v.reset(OpCopy)"))
rr.add(stmtf("v.AddArg(%s)", v))
}
}
if auxint != "" {
fmt.Fprintf(w, "%s.AuxInt = %s\n", v, auxint)
rr.add(stmtf("%s.AuxInt = %s", v, auxint))
}
if aux != "" {
fmt.Fprintf(w, "%s.Aux = %s\n", v, aux)
rr.add(stmtf("%s.Aux = %s", v, aux))
}
for _, arg := range args {
x := genResult0(w, arch, arg, alloc, false, move, loc, pos)
fmt.Fprintf(w, "%s.AddArg(%s)\n", v, x)
x := genResult0(rr, arch, arg, false, move, pos)
rr.add(stmtf("%s.AddArg(%s)", v, x))
}
return v
......@@ -652,7 +811,7 @@ outer:
}
}
if d != 0 {
panic("imbalanced expression: " + s)
log.Fatalf("imbalanced expression: %q", s)
}
if nonsp {
r = append(r, strings.TrimSpace(s))
......@@ -677,7 +836,7 @@ func isBlock(name string, arch arch) bool {
return false
}
func extract(val string) (op string, typ string, auxint string, aux string, args []string) {
func extract(val string) (op, typ, auxint, aux string, args []string) {
val = val[1 : len(val)-1] // remove ()
// Split val up into regions.
......@@ -705,7 +864,7 @@ func extract(val string) (op string, typ string, auxint string, aux string, args
// The value can be from the match or the result side.
// It returns the op and unparsed strings for typ, auxint, and aux restrictions and for all args.
// oparch is the architecture that op is located in, or "" for generic.
func parseValue(val string, arch arch, loc string) (op opData, oparch string, typ string, auxint string, aux string, args []string) {
func parseValue(val string, arch arch, loc string) (op opData, oparch, typ, auxint, aux string, args []string) {
// Resolve the op.
var s string
s, typ, auxint, aux, args = extract(val)
......@@ -723,9 +882,8 @@ func parseValue(val string, arch arch, loc string) (op opData, oparch string, ty
if x.argLength != -1 && int(x.argLength) != len(args) {
if strict {
return false
} else {
log.Printf("%s: op %s (%s) should have %d args, has %d", loc, s, archname, x.argLength, len(args))
}
log.Printf("%s: op %s (%s) should have %d args, has %d", loc, s, archname, x.argLength, len(args))
}
return true
}
......@@ -736,9 +894,8 @@ func parseValue(val string, arch arch, loc string) (op opData, oparch string, ty
break
}
}
if arch.name != "generic" {
for _, x := range arch.ops {
if match(x, true, arch.name) {
if arch.name != "generic" && match(x, true, arch.name) {
if op.name != "" {
log.Fatalf("%s: matches for op %s found in both generic and %s", loc, op.name, arch.name)
}
......@@ -747,7 +904,6 @@ func parseValue(val string, arch arch, loc string) (op opData, oparch string, ty
break
}
}
}
if op.name == "" {
// Failed to find the op.
......@@ -777,7 +933,6 @@ func parseValue(val string, arch arch, loc string) (op opData, oparch string, ty
log.Fatalf("%s: op %s %s can't have aux", loc, op.name, op.aux)
}
}
return
}
......@@ -795,7 +950,7 @@ func typeName(typ string) string {
if typ[0] == '(' {
ts := strings.Split(typ[1:len(typ)-1], ",")
if len(ts) != 2 {
panic("Tuple expect 2 arguments")
log.Fatalf("Tuple expect 2 arguments")
}
return "types.NewTuple(" + typeName(ts[0]) + ", " + typeName(ts[1]) + ")"
}
......@@ -809,29 +964,19 @@ func typeName(typ string) string {
// unbalanced reports whether there aren't the same number of ( and ) in the string.
func unbalanced(s string) bool {
var left, right int
balance := 0
for _, c := range s {
if c == '(' {
left++
}
if c == ')' {
right++
balance++
} else if c == ')' {
balance--
}
}
return left != right
return balance != 0
}
// isVariable reports whether s is a single Go alphanumeric identifier.
func isVariable(s string) bool {
b, err := regexp.MatchString("^[A-Za-z_][A-Za-z_0-9]*$", s)
if err != nil {
panic("bad variable regexp")
}
return b
}
// opRegexp is a regular expression to find the opcode portion of s-expressions.
var opRegexp = regexp.MustCompile(`[(](\w+[|])+\w+[)]`)
// findAllOpcode finds the opcode portions of s-expressions.
var findAllOpcode = regexp.MustCompile(`[(](\w+[|])+\w+[)]`).FindAllStringIndex
// excludeFromExpansion reports whether the substring s[idx[0]:idx[1]] in a rule
// should be disregarded as a candidate for | expansion.
......@@ -859,7 +1004,7 @@ func expandOr(r string) []string {
// Count width of |-forms. They must match.
n := 1
for _, idx := range opRegexp.FindAllStringIndex(r, -1) {
for _, idx := range findAllOpcode(r, -1) {
if excludeFromExpansion(r, idx) {
continue
}
......@@ -882,7 +1027,7 @@ func expandOr(r string) []string {
for i := 0; i < n; i++ {
buf := new(strings.Builder)
x := 0
for _, idx := range opRegexp.FindAllStringIndex(r, -1) {
for _, idx := range findAllOpcode(r, -1) {
if excludeFromExpansion(r, idx) {
continue
}
......@@ -913,7 +1058,7 @@ func commute(r string, arch arch) []string {
if len(a) == 1 && normalizeWhitespace(r) != normalizeWhitespace(a[0]) {
fmt.Println(normalizeWhitespace(r))
fmt.Println(normalizeWhitespace(a[0]))
panic("commute() is not the identity for noncommuting rule")
log.Fatalf("commute() is not the identity for noncommuting rule")
}
if false && len(a) > 1 {
fmt.Println(r)
......@@ -925,18 +1070,17 @@ func commute(r string, arch arch) []string {
}
func commute1(m string, cnt map[string]int, arch arch) []string {
if m[0] == '<' || m[0] == '[' || m[0] == '{' || isVariable(m) {
if m[0] == '<' || m[0] == '[' || m[0] == '{' || token.IsIdentifier(m) {
return []string{m}
}
// Split up input.
var prefix string
colon := strings.Index(m, ":")
if colon >= 0 && isVariable(m[:colon]) {
prefix = m[:colon+1]
m = m[colon+1:]
if i := strings.Index(m, ":"); i >= 0 && token.IsIdentifier(m[:i]) {
prefix = m[:i+1]
m = m[i+1:]
}
if m[0] != '(' || m[len(m)-1] != ')' {
panic("non-compound expr in commute1: " + m)
log.Fatalf("non-compound expr in commute1: %q", m)
}
s := split(m[1 : len(m)-1])
op := s[0]
......@@ -978,7 +1122,7 @@ func commute1(m string, cnt map[string]int, arch arch) []string {
}
}
if idx1 == 0 {
panic("couldn't find first two args of commutative op " + s[0])
log.Fatalf("couldn't find first two args of commutative op %q", s[0])
}
if cnt[s[idx0]] == 1 && cnt[s[idx1]] == 1 || s[idx0] == s[idx1] && cnt[s[idx0]] == 2 {
// When we have (Add x y) with no other uses of x and y in the matching rule,
......@@ -1016,22 +1160,22 @@ func varCount(m string) map[string]int {
varCount1(m, cnt)
return cnt
}
func varCount1(m string, cnt map[string]int) {
if m[0] == '<' || m[0] == '[' || m[0] == '{' {
return
}
if isVariable(m) {
if token.IsIdentifier(m) {
cnt[m]++
return
}
// Split up input.
colon := strings.Index(m, ":")
if colon >= 0 && isVariable(m[:colon]) {
cnt[m[:colon]]++
m = m[colon+1:]
if i := strings.Index(m, ":"); i >= 0 && token.IsIdentifier(m[:i]) {
cnt[m[:i]]++
m = m[i+1:]
}
if m[0] != '(' || m[len(m)-1] != ')' {
panic("non-compound expr in commute1: " + m)
log.Fatalf("non-compound expr in commute1: %q", m)
}
s := split(m[1 : len(m)-1])
for _, arg := range s[1:] {
......