// Copyright 2015 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // +build gen // This program generates Go code that applies rewrite rules to a Value. // The generated code implements a function of type func (v *Value) bool // which returns true iff if did something. // Ideas stolen from Swift: http://www.hpl.hp.com/techreports/Compaq-DEC/WRL-2000-2.html package main import ( "bufio" "bytes" "flag" "fmt" "go/format" "io" "io/ioutil" "log" "os" "regexp" "sort" "strings" ) // rule syntax: // sexpr [&& extra conditions] -> [@block] sexpr // // sexpr are s-expressions (lisp-like parenthesized groupings) // sexpr ::= (opcode sexpr*) // | variable // | <type> // | [auxint] // | {aux} // // aux ::= variable | {code} // type ::= variable | {code} // variable ::= some token // opcode ::= one of the opcodes from ../op.go (without the Op prefix) // extra conditions is just a chunk of Go that evaluates to a boolean. It may use // variables declared in the matching sexpr. The variable "v" is predefined to be // the value matched by the entire rule. // If multiple rules match, the first one in file order is selected. var ( genLog = flag.Bool("log", false, "generate code that logs; for debugging only") ) type Rule struct { rule string loc string // file name & line number } func (r Rule) String() string { return fmt.Sprintf("rule %q at %s", r.rule, r.loc) } // parse returns the matching part of the rule, additional conditions, and the result. func (r Rule) parse() (match, cond, result string) { s := strings.Split(r.rule, "->") if len(s) != 2 { log.Fatalf("no arrow in %s", r) } match = strings.TrimSpace(s[0]) result = strings.TrimSpace(s[1]) cond = "" if i := strings.Index(match, "&&"); i >= 0 { cond = strings.TrimSpace(match[i+2:]) match = strings.TrimSpace(match[:i]) } return match, cond, result } func genRules(arch arch) { // Open input file. text, err := os.Open(arch.name + ".rules") if err != nil { log.Fatalf("can't read rule file: %v", err) } // oprules contains a list of rules for each block and opcode blockrules := map[string][]Rule{} oprules := map[string][]Rule{} // read rule file scanner := bufio.NewScanner(text) rule := "" var lineno int for scanner.Scan() { lineno++ line := scanner.Text() if i := strings.Index(line, "//"); i >= 0 { // Remove comments. Note that this isn't string safe, so // it will truncate lines with // inside strings. Oh well. line = line[:i] } rule += " " + line rule = strings.TrimSpace(rule) if rule == "" { continue } if !strings.Contains(rule, "->") { continue } if strings.HasSuffix(rule, "->") { continue } if unbalanced(rule) { continue } op := strings.Split(rule, " ")[0][1:] if op[len(op)-1] == ')' { op = op[:len(op)-1] // rule has only opcode, e.g. (ConstNil) -> ... } loc := fmt.Sprintf("%s.rules:%d", arch.name, lineno) if isBlock(op, arch) { blockrules[op] = append(blockrules[op], Rule{rule: rule, loc: loc}) } else { oprules[op] = append(oprules[op], Rule{rule: rule, loc: loc}) } rule = "" } if err := scanner.Err(); err != nil { log.Fatalf("scanner failed: %v\n", err) } if unbalanced(rule) { log.Fatalf("%s.rules:%d: unbalanced rule: %v\n", arch.name, lineno, rule) } // Order all the ops. var ops []string for op := range oprules { ops = append(ops, op) } sort.Strings(ops) // Start output buffer, write header. w := new(bytes.Buffer) fmt.Fprintf(w, "// autogenerated from gen/%s.rules: do not edit!\n", arch.name) fmt.Fprintln(w, "// generated with: cd gen; go run *.go") fmt.Fprintln(w) fmt.Fprintln(w, "package ssa") if *genLog { fmt.Fprintln(w, "import \"fmt\"") } fmt.Fprintln(w, "import \"math\"") fmt.Fprintln(w, "var _ = math.MinInt8 // in case not otherwise used") // Main rewrite routine is a switch on v.Op. fmt.Fprintf(w, "func rewriteValue%s(v *Value, config *Config) bool {\n", arch.name) fmt.Fprintf(w, "switch v.Op {\n") for _, op := range ops { fmt.Fprintf(w, "case %s:\n", opName(op, arch)) fmt.Fprintf(w, "return rewriteValue%s_%s(v, config)\n", arch.name, opName(op, arch)) } fmt.Fprintf(w, "}\n") fmt.Fprintf(w, "return false\n") fmt.Fprintf(w, "}\n") // Generate a routine per op. Note that we don't make one giant routine // because it is too big for some compilers. for _, op := range ops { fmt.Fprintf(w, "func rewriteValue%s_%s(v *Value, config *Config) bool {\n", arch.name, opName(op, arch)) fmt.Fprintln(w, "b := v.Block") fmt.Fprintln(w, "_ = b") for _, rule := range oprules[op] { match, cond, result := rule.parse() fmt.Fprintf(w, "// match: %s\n", match) fmt.Fprintf(w, "// cond: %s\n", cond) fmt.Fprintf(w, "// result: %s\n", result) fmt.Fprintf(w, "for {\n") genMatch(w, arch, match, rule.loc) if cond != "" { fmt.Fprintf(w, "if !(%s) {\nbreak\n}\n", cond) } genResult(w, arch, result, rule.loc) if *genLog { fmt.Fprintf(w, "fmt.Println(\"rewrite %s\")\n", rule.loc) } fmt.Fprintf(w, "return true\n") fmt.Fprintf(w, "}\n") } fmt.Fprintf(w, "return false\n") fmt.Fprintf(w, "}\n") } // Generate block rewrite function. There are only a few block types // so we can make this one function with a switch. fmt.Fprintf(w, "func rewriteBlock%s(b *Block) bool {\n", arch.name) fmt.Fprintf(w, "switch b.Kind {\n") ops = nil for op := range blockrules { ops = append(ops, op) } sort.Strings(ops) for _, op := range ops { fmt.Fprintf(w, "case %s:\n", blockName(op, arch)) for _, rule := range blockrules[op] { match, cond, result := rule.parse() fmt.Fprintf(w, "// match: %s\n", match) fmt.Fprintf(w, "// cond: %s\n", cond) fmt.Fprintf(w, "// result: %s\n", result) fmt.Fprintf(w, "for {\n") s := split(match[1 : len(match)-1]) // remove parens, then split // check match of control value if s[1] != "nil" { fmt.Fprintf(w, "v := b.Control\n") if strings.Contains(s[1], "(") { genMatch0(w, arch, s[1], "v", map[string]struct{}{}, false, rule.loc) } else { fmt.Fprintf(w, "%s := b.Control\n", s[1]) } } // assign successor names succs := s[2:] for i, a := range succs { if a != "_" { fmt.Fprintf(w, "%s := b.Succs[%d]\n", a, i) } } if cond != "" { fmt.Fprintf(w, "if !(%s) {\nbreak\n}\n", cond) } // Rule matches. Generate result. t := split(result[1 : len(result)-1]) // remove parens, then split newsuccs := t[2:] // Check if newsuccs is the same set as succs. m := map[string]bool{} for _, succ := range succs { if m[succ] { log.Fatalf("can't have a repeat successor name %s in %s", succ, rule) } m[succ] = true } for _, succ := range newsuccs { if !m[succ] { log.Fatalf("unknown successor %s in %s", succ, rule) } delete(m, succ) } if len(m) != 0 { log.Fatalf("unmatched successors %v in %s", m, rule) } // Modify predecessor lists for no-longer-reachable blocks for succ := range m { fmt.Fprintf(w, "b.Func.removePredecessor(b, %s)\n", succ) } fmt.Fprintf(w, "b.Kind = %s\n", blockName(t[0], arch)) if t[1] == "nil" { fmt.Fprintf(w, "b.SetControl(nil)\n") } else { fmt.Fprintf(w, "b.SetControl(%s)\n", genResult0(w, arch, t[1], new(int), false, false, rule.loc)) } if len(newsuccs) < len(succs) { fmt.Fprintf(w, "b.Succs = b.Succs[:%d]\n", len(newsuccs)) } for i, a := range newsuccs { fmt.Fprintf(w, "b.Succs[%d] = %s\n", i, a) } // Update branch prediction switch { case len(newsuccs) != 2: fmt.Fprintln(w, "b.Likely = BranchUnknown") case newsuccs[0] == succs[0] && newsuccs[1] == succs[1]: // unchanged case newsuccs[0] == succs[1] && newsuccs[1] == succs[0]: // flipped fmt.Fprintln(w, "b.Likely *= -1") default: // unknown fmt.Fprintln(w, "b.Likely = BranchUnknown") } if *genLog { fmt.Fprintf(w, "fmt.Println(\"rewrite %s\")\n", rule.loc) } fmt.Fprintf(w, "return true\n") fmt.Fprintf(w, "}\n") } } fmt.Fprintf(w, "}\n") fmt.Fprintf(w, "return false\n") fmt.Fprintf(w, "}\n") // gofmt result b := w.Bytes() src, err := format.Source(b) if err != nil { fmt.Printf("%s\n", b) panic(err) } // Write to file err = ioutil.WriteFile("../rewrite"+arch.name+".go", src, 0666) if err != nil { log.Fatalf("can't write output: %v\n", err) } } func genMatch(w io.Writer, arch arch, match string, loc string) { genMatch0(w, arch, match, "v", map[string]struct{}{}, true, loc) } func genMatch0(w io.Writer, arch arch, match, v string, m map[string]struct{}, top bool, loc string) { if match[0] != '(' || match[len(match)-1] != ')' { panic("non-compound expr in genMatch0: " + match) } // split body up into regions. Split by spaces/tabs, except those // contained in () or {}. s := split(match[1 : len(match)-1]) // remove parens, then split // Find op record var op opData for _, x := range genericOps { if x.name == s[0] { op = x break } } for _, x := range arch.ops { if x.name == s[0] { op = x break } } if op.name == "" { log.Fatalf("%s: unknown op %s", loc, s[0]) } // check op if !top { fmt.Fprintf(w, "if %s.Op != %s {\nbreak\n}\n", v, opName(s[0], arch)) } // check type/aux/args argnum := 0 for _, a := range s[1:] { if a[0] == '<' { // type restriction t := a[1 : len(a)-1] // remove <> if !isVariable(t) { // code. We must match the results of this code. fmt.Fprintf(w, "if %s.Type != %s {\nbreak\n}\n", v, t) } else { // variable if _, ok := m[t]; ok { // must match previous variable fmt.Fprintf(w, "if %s.Type != %s {\nbreak\n}\n", v, t) } else { m[t] = struct{}{} fmt.Fprintf(w, "%s := %s.Type\n", t, v) } } } else if a[0] == '[' { // auxint restriction switch op.aux { case "Bool", "Int8", "Int16", "Int32", "Int64", "Int128", "Float32", "Float64", "SymOff", "SymValAndOff", "SymInt32": default: log.Fatalf("%s: op %s %s can't have auxint", loc, op.name, op.aux) } x := a[1 : len(a)-1] // remove [] if !isVariable(x) { // code fmt.Fprintf(w, "if %s.AuxInt != %s {\nbreak\n}\n", v, x) } else { // variable if _, ok := m[x]; ok { fmt.Fprintf(w, "if %s.AuxInt != %s {\nbreak\n}\n", v, x) } else { m[x] = struct{}{} fmt.Fprintf(w, "%s := %s.AuxInt\n", x, v) } } } else if a[0] == '{' { // aux restriction switch op.aux { case "String", "Sym", "SymOff", "SymValAndOff", "SymInt32": default: log.Fatalf("%s: op %s %s can't have aux", loc, op.name, op.aux) } x := a[1 : len(a)-1] // remove {} if !isVariable(x) { // code fmt.Fprintf(w, "if %s.Aux != %s {\nbreak\n}\n", v, x) } else { // variable if _, ok := m[x]; ok { fmt.Fprintf(w, "if %s.Aux != %s {\nbreak\n}\n", v, x) } else { m[x] = struct{}{} fmt.Fprintf(w, "%s := %s.Aux\n", x, v) } } } else if a == "_" { argnum++ } else if !strings.Contains(a, "(") { // leaf variable if _, ok := m[a]; ok { // variable already has a definition. Check whether // the old definition and the new definition match. // For example, (add x x). Equality is just pointer equality // on Values (so cse is important to do before lowering). fmt.Fprintf(w, "if %s != %s.Args[%d] {\nbreak\n}\n", a, v, argnum) } else { // remember that this variable references the given value m[a] = struct{}{} fmt.Fprintf(w, "%s := %s.Args[%d]\n", a, v, argnum) } argnum++ } else { // compound sexpr var argname string colon := strings.Index(a, ":") openparen := strings.Index(a, "(") if colon >= 0 && openparen >= 0 && colon < openparen { // rule-specified name argname = a[:colon] a = a[colon+1:] } else { // autogenerated name argname = fmt.Sprintf("%s_%d", v, argnum) } fmt.Fprintf(w, "%s := %s.Args[%d]\n", argname, v, argnum) genMatch0(w, arch, a, argname, m, false, loc) argnum++ } } if op.argLength == -1 { fmt.Fprintf(w, "if len(%s.Args) != %d {\nbreak\n}\n", v, argnum) } else if int(op.argLength) != argnum { log.Fatalf("%s: op %s should have %d args, has %d", loc, op.name, op.argLength, argnum) } } func genResult(w io.Writer, arch arch, result string, loc string) { move := false if result[0] == '@' { // parse @block directive s := strings.SplitN(result[1:], " ", 2) fmt.Fprintf(w, "b = %s\n", s[0]) result = s[1] move = true } genResult0(w, arch, result, new(int), true, move, loc) } func genResult0(w io.Writer, arch arch, result string, alloc *int, top, move bool, loc string) string { // TODO: when generating a constant result, use f.constVal to avoid // introducing copies just to clean them up again. if result[0] != '(' { // variable if top { // It in not safe in general to move a variable between blocks // (and particularly not a phi node). // Introduce a copy. fmt.Fprintf(w, "v.reset(OpCopy)\n") fmt.Fprintf(w, "v.Type = %s.Type\n", result) fmt.Fprintf(w, "v.AddArg(%s)\n", result) } return result } s := split(result[1 : len(result)-1]) // remove parens, then split // Find op record var op opData for _, x := range genericOps { if x.name == s[0] { op = x break } } for _, x := range arch.ops { if x.name == s[0] { op = x break } } if op.name == "" { log.Fatalf("%s: unknown op %s", loc, s[0]) } // Find the type of the variable. var opType string var typeOverride bool for _, a := range s[1:] { if a[0] == '<' { // type restriction opType = a[1 : len(a)-1] // remove <> typeOverride = true break } } if opType == "" { // find default type, if any for _, op := range arch.ops { if op.name == s[0] && op.typ != "" { opType = typeName(op.typ) break } } } if opType == "" { for _, op := range genericOps { if op.name == s[0] && op.typ != "" { opType = typeName(op.typ) break } } } var v string if top && !move { v = "v" fmt.Fprintf(w, "v.reset(%s)\n", opName(s[0], arch)) if typeOverride { fmt.Fprintf(w, "v.Type = %s\n", opType) } } else { if opType == "" { log.Fatalf("sub-expression %s (op=%s) must have a type", result, s[0]) } v = fmt.Sprintf("v%d", *alloc) *alloc++ fmt.Fprintf(w, "%s := b.NewValue0(v.Line, %s, %s)\n", v, opName(s[0], arch), opType) if move && top { // Rewrite original into a copy fmt.Fprintf(w, "v.reset(OpCopy)\n") fmt.Fprintf(w, "v.AddArg(%s)\n", v) } } argnum := 0 for _, a := range s[1:] { if a[0] == '<' { // type restriction, handled above } else if a[0] == '[' { // auxint restriction switch op.aux { case "Bool", "Int8", "Int16", "Int32", "Int64", "Int128", "Float32", "Float64", "SymOff", "SymValAndOff", "SymInt32": default: log.Fatalf("%s: op %s %s can't have auxint", loc, op.name, op.aux) } x := a[1 : len(a)-1] // remove [] fmt.Fprintf(w, "%s.AuxInt = %s\n", v, x) } else if a[0] == '{' { // aux restriction switch op.aux { case "String", "Sym", "SymOff", "SymValAndOff", "SymInt32": default: log.Fatalf("%s: op %s %s can't have aux", loc, op.name, op.aux) } x := a[1 : len(a)-1] // remove {} fmt.Fprintf(w, "%s.Aux = %s\n", v, x) } else { // regular argument (sexpr or variable) x := genResult0(w, arch, a, alloc, false, move, loc) fmt.Fprintf(w, "%s.AddArg(%s)\n", v, x) argnum++ } } if op.argLength != -1 && int(op.argLength) != argnum { log.Fatalf("%s: op %s should have %d args, has %d", loc, op.name, op.argLength, argnum) } return v } func split(s string) []string { var r []string outer: for s != "" { d := 0 // depth of ({[< var open, close byte // opening and closing markers ({[< or )}]> nonsp := false // found a non-space char so far for i := 0; i < len(s); i++ { switch { case d == 0 && s[i] == '(': open, close = '(', ')' d++ case d == 0 && s[i] == '<': open, close = '<', '>' d++ case d == 0 && s[i] == '[': open, close = '[', ']' d++ case d == 0 && s[i] == '{': open, close = '{', '}' d++ case d == 0 && (s[i] == ' ' || s[i] == '\t'): if nonsp { r = append(r, strings.TrimSpace(s[:i])) s = s[i:] continue outer } case d > 0 && s[i] == open: d++ case d > 0 && s[i] == close: d-- default: nonsp = true } } if d != 0 { panic("imbalanced expression: " + s) } if nonsp { r = append(r, strings.TrimSpace(s)) } break } return r } // isBlock returns true if this op is a block opcode. func isBlock(name string, arch arch) bool { for _, b := range genericBlocks { if b.name == name { return true } } for _, b := range arch.blocks { if b.name == name { return true } } return false } // opName converts from an op name specified in a rule file to an Op enum. // if the name matches a generic op, returns "Op" plus the specified name. // Otherwise, returns "Op" plus arch name plus op name. func opName(name string, arch arch) string { for _, op := range genericOps { if op.name == name { return "Op" + name } } return "Op" + arch.name + name } func blockName(name string, arch arch) string { for _, b := range genericBlocks { if b.name == name { return "Block" + name } } return "Block" + arch.name + name } // typeName returns the string to use to generate a type. func typeName(typ string) string { switch typ { case "Flags", "Mem", "Void", "Int128": return "Type" + typ default: return "config.fe.Type" + typ + "()" } } // unbalanced returns true if there aren't the same number of ( and ) in the string. func unbalanced(s string) bool { var left, right int for _, c := range s { if c == '(' { left++ } if c == ')' { right++ } } return left != right } // isVariable reports whether s is a single Go alphanumeric identifier. func isVariable(s string) bool { b, err := regexp.MatchString("^[A-Za-z_][A-Za-z_0-9]*$", s) if err != nil { panic("bad variable regexp") } return b }