Commit 877c1892 authored by Russ Cox's avatar Russ Cox

gofix: add -diff, various fixes and helpers

  * add -diff command line option
  * use scoping information in refersTo, isPkgDot, isPtrPkgDot.
  * add new scoping-based helpers countUses, rewriteUses, assignsTo, isTopName.
  * rename rewrite to walk, add walkBeforeAfter.
  * add toy typechecker, a placeholder for go/types

R=gri
CC=golang-dev
https://golang.org/cl/4285053
parent fb175cf7
This diff is collapsed.
......@@ -41,7 +41,7 @@ func httpserver(f *ast.File) bool {
if !ok {
continue
}
rewrite(fn.Body, func(n interface{}) {
walk(fn.Body, func(n interface{}) {
// Want to replace expression sometimes,
// so record pointer to it for updating below.
ptr, ok := n.(*ast.Expr)
......
......@@ -6,6 +6,7 @@ package main
import (
"bytes"
"exec"
"flag"
"fmt"
"go/parser"
......@@ -29,8 +30,10 @@ var allowedRewrites = flag.String("r", "",
var allowed map[string]bool
var doDiff = flag.Bool("diff", false, "display diffs instead of rewriting files")
func usage() {
fmt.Fprintf(os.Stderr, "usage: gofix [-r fixname,...] [path ...]\n")
fmt.Fprintf(os.Stderr, "usage: gofix [-diff] [-r fixname,...] [path ...]\n")
flag.PrintDefaults()
fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n")
for _, f := range fixes {
......@@ -85,10 +88,16 @@ const (
printerMode = printer.TabIndent | printer.UseSpaces
)
var printConfig = &printer.Config{
printerMode,
tabWidth,
}
func processFile(filename string, useStdin bool) os.Error {
var f *os.File
var err os.Error
var fixlog bytes.Buffer
var buf bytes.Buffer
if useStdin {
f = os.Stdin
......@@ -110,34 +119,77 @@ func processFile(filename string, useStdin bool) os.Error {
return err
}
// Apply all fixes to file.
newFile := file
fixed := false
var buf bytes.Buffer
for _, fix := range fixes {
if allowed != nil && !allowed[fix.desc] {
continue
}
if fix.f(file) {
if fix.f(newFile) {
fixed = true
fmt.Fprintf(&buf, " %s", fix.name)
fmt.Fprintf(&fixlog, " %s", fix.name)
// AST changed.
// Print and parse, to update any missing scoping
// or position information for subsequent fixers.
buf.Reset()
_, err = printConfig.Fprint(&buf, fset, newFile)
if err != nil {
return err
}
newSrc := buf.Bytes()
newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode)
if err != nil {
return err
}
}
}
if !fixed {
return nil
}
fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, buf.String()[1:])
fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, fixlog.String()[1:])
// Print AST. We did that after each fix, so this appears
// redundant, but it is necessary to generate gofmt-compatible
// source code in a few cases. The official gofmt style is the
// output of the printer run on a standard AST generated by the parser,
// but the source we generated inside the loop above is the
// output of the printer run on a mangled AST generated by a fixer.
buf.Reset()
_, err = (&printer.Config{printerMode, tabWidth}).Fprint(&buf, fset, file)
_, err = printConfig.Fprint(&buf, fset, newFile)
if err != nil {
return err
}
newSrc := buf.Bytes()
if *doDiff {
data, err := diff(src, newSrc)
if err != nil {
return fmt.Errorf("computing diff: %s", err)
}
fmt.Printf("diff %s fixed/%s\n", filename, filename)
os.Stdout.Write(data)
return nil
}
if useStdin {
os.Stdout.Write(buf.Bytes())
os.Stdout.Write(newSrc)
return nil
}
return ioutil.WriteFile(f.Name(), buf.Bytes(), 0)
return ioutil.WriteFile(f.Name(), newSrc, 0)
}
var gofmtBuf bytes.Buffer
func gofmt(n interface{}) string {
gofmtBuf.Reset()
_, err := printConfig.Fprint(&gofmtBuf, fset, n)
if err != nil {
return "<" + err.String() + ">"
}
return gofmtBuf.String()
}
func report(err os.Error) {
......@@ -177,3 +229,36 @@ func isGoFile(f *os.FileInfo) bool {
// ignore non-Go files
return f.IsRegular() && !strings.HasPrefix(f.Name, ".") && strings.HasSuffix(f.Name, ".go")
}
func diff(b1, b2 []byte) (data []byte, err os.Error) {
f1, err := ioutil.TempFile("", "gofix")
if err != nil {
return nil, err
}
defer os.Remove(f1.Name())
defer f1.Close()
f2, err := ioutil.TempFile("", "gofix")
if err != nil {
return nil, err
}
defer os.Remove(f2.Name())
defer f2.Close()
f1.Write(b1)
f2.Write(b2)
diffcmd, err := exec.LookPath("diff")
if err != nil {
return nil, err
}
c, err := exec.Run(diffcmd, []string{"diff", f1.Name(), f2.Name()}, nil, "",
exec.DevNull, exec.Pipe, exec.MergeWithStdout)
if err != nil {
return nil, err
}
defer c.Close()
return ioutil.ReadAll(c.Stdout)
}
......@@ -6,12 +6,10 @@ package main
import (
"bytes"
"exec"
"go/ast"
"go/parser"
"go/printer"
"io/ioutil"
"os"
"strings"
"testing"
)
......@@ -28,6 +26,8 @@ func addTestCases(t []testCase) {
testCases = append(testCases, t...)
}
func fnop(*ast.File) bool { return false }
func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string) (out string, fixed, ok bool) {
file, err := parser.ParseFile(fset, desc, in, parserMode)
if err != nil {
......@@ -42,7 +42,7 @@ func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string) (out
t.Errorf("%s: printing: %v", desc, err)
return
}
if s := buf.String(); in != s {
if s := buf.String(); in != s && fn != fnop {
t.Errorf("%s: not gofmt-formatted.\n--- %s\n%s\n--- %s | gofmt\n%s",
desc, desc, in, desc, s)
tdiff(t, in, s)
......@@ -77,8 +77,17 @@ func TestRewrite(t *testing.T) {
continue
}
// reformat to get printing right
out, _, ok = parseFixPrint(t, fnop, tt.Name, out)
if !ok {
continue
}
if out != tt.Out {
t.Errorf("%s: incorrect output.\n--- have\n%s\n--- want\n%s", tt.Name, out, tt.Out)
t.Errorf("%s: incorrect output.\n", tt.Name)
if !strings.HasPrefix(tt.Name, "testdata/") {
t.Errorf("--- have\n%s\n--- want\n%s", out, tt.Out)
}
tdiff(t, out, tt.Out)
continue
}
......@@ -108,44 +117,10 @@ func TestRewrite(t *testing.T) {
}
func tdiff(t *testing.T, a, b string) {
f1, err := ioutil.TempFile("", "gofix")
if err != nil {
t.Error(err)
return
}
defer os.Remove(f1.Name())
defer f1.Close()
f2, err := ioutil.TempFile("", "gofix")
if err != nil {
t.Error(err)
return
}
defer os.Remove(f2.Name())
defer f2.Close()
f1.Write([]byte(a))
f2.Write([]byte(b))
diffcmd, err := exec.LookPath("diff")
if err != nil {
t.Error(err)
return
}
c, err := exec.Run(diffcmd, []string{"diff", f1.Name(), f2.Name()}, nil, "",
exec.DevNull, exec.Pipe, exec.MergeWithStdout)
data, err := diff([]byte(a), []byte(b))
if err != nil {
t.Error(err)
return
}
defer c.Close()
data, err := ioutil.ReadAll(c.Stdout)
if err != nil {
t.Error(err)
return
}
t.Error(string(data))
}
......@@ -47,7 +47,7 @@ func netdial(f *ast.File) bool {
}
fixed := false
rewrite(f, func(n interface{}) {
walk(f, func(n interface{}) {
call, ok := n.(*ast.CallExpr)
if !ok || !isPkgDot(call.Fun, "net", "Dial") || len(call.Args) != 3 {
return
......@@ -70,7 +70,7 @@ func tlsdial(f *ast.File) bool {
}
fixed := false
rewrite(f, func(n interface{}) {
walk(f, func(n interface{}) {
call, ok := n.(*ast.CallExpr)
if !ok || !isPkgDot(call.Fun, "tls", "Dial") || len(call.Args) != 4 {
return
......@@ -94,7 +94,7 @@ func netlookup(f *ast.File) bool {
}
fixed := false
rewrite(f, func(n interface{}) {
walk(f, func(n interface{}) {
as, ok := n.(*ast.AssignStmt)
if !ok || len(as.Lhs) != 3 || len(as.Rhs) != 1 {
return
......
......@@ -27,7 +27,7 @@ func osopen(f *ast.File) bool {
}
fixed := false
rewrite(f, func(n interface{}) {
walk(f, func(n interface{}) {
// Rename O_CREAT to O_CREATE.
if expr, ok := n.(ast.Expr); ok && isPkgDot(expr, "os", "O_CREAT") {
expr.(*ast.SelectorExpr).Sel.Name = "O_CREATE"
......
......@@ -28,7 +28,7 @@ func procattr(f *ast.File) bool {
}
fixed := false
rewrite(f, func(n interface{}) {
walk(f, func(n interface{}) {
call, ok := n.(*ast.CallExpr)
if !ok || len(call.Args) != 5 {
return
......
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment