Commit a34f1bbb authored by Russ Cox's avatar Russ Cox

gofix: new command for updating code to new release

R=bradfitzgo, dsymonds, r, gri, adg
CC=golang-dev
https://golang.org/cl/4282044
parent 66f09fd4
......@@ -41,6 +41,7 @@ CLEANDIRS=\
cgo\
ebnflint\
godoc\
gofix\
gofmt\
goinstall\
gotype\
......
# Copyright 2011 The Go Authors. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.
include ../../Make.inc
TARG=gofix
GOFILES=\
fix.go\
main.go\
httpserver.go\
include ../../Make.cmd
test:
gotest
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Gofix finds Go programs that use old APIs and rewrites them to use
newer ones. After you update to a new Go release, gofix helps make
the necessary changes to your programs.
Usage:
gofix [-r name,...] [path ...]
Without an explicit path, gofix reads standard input and writes the
result to standard output.
If the named path is a file, gofix rewrites the named files in place.
If the named path is a directory, gofix rewrites all .go files in that
directory tree. When gofix rewrites a file, it prints a line to standard
error giving the name of the file and the rewrite applied.
The -r flag restricts the set of rewrites considered to those in the
named list. By default gofix considers all known rewrites. Gofix's
rewrites are idempotent, so that it is safe to apply gofix to updated
or partially updated code even without using the -r flag.
Gofix prints the full list of fixes it can apply in its help output;
to see them, run godoc -?.
Gofix does not make backup copies of the files that it edits.
Instead, use a version control system's ``diff'' functionality to inspect
the changes that gofix makes before committing them.
*/
package documentation
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"fmt"
"go/ast"
"go/token"
"os"
)
type fix struct {
name string
f func(*ast.File) bool
desc string
}
// main runs sort.Sort(fixes) after init process is done.
type fixlist []fix
func (f fixlist) Len() int { return len(f) }
func (f fixlist) Swap(i, j int) { f[i], f[j] = f[j], f[i] }
func (f fixlist) Less(i, j int) bool { return f[i].name < f[j].name }
var fixes fixlist
func register(f fix) {
fixes = append(fixes, f)
}
// rewrite walks the AST x, calling visit(y) for each node y in the tree but
// also with a pointer to each ast.Expr, in a bottom-up traversal.
func rewrite(x interface{}, visit func(interface{})) {
switch n := x.(type) {
case *ast.Expr:
rewrite(*n, visit)
// everything else just recurses
default:
panic(fmt.Errorf("unexpected type %T in walk", x, visit))
case nil:
// These are ordered and grouped to match ../../pkg/go/ast/ast.go
case *ast.Field:
rewrite(&n.Type, visit)
case *ast.FieldList:
for _, field := range n.List {
rewrite(field, visit)
}
case *ast.BadExpr:
case *ast.Ident:
case *ast.Ellipsis:
case *ast.BasicLit:
case *ast.FuncLit:
rewrite(n.Type, visit)
rewrite(n.Body, visit)
case *ast.CompositeLit:
rewrite(&n.Type, visit)
rewrite(n.Elts, visit)
case *ast.ParenExpr:
rewrite(&n.X, visit)
case *ast.SelectorExpr:
rewrite(&n.X, visit)
case *ast.IndexExpr:
rewrite(&n.X, visit)
rewrite(&n.Index, visit)
case *ast.SliceExpr:
rewrite(&n.X, visit)
if n.Low != nil {
rewrite(&n.Low, visit)
}
if n.High != nil {
rewrite(&n.High, visit)
}
case *ast.TypeAssertExpr:
rewrite(&n.X, visit)
rewrite(&n.Type, visit)
case *ast.CallExpr:
rewrite(&n.Fun, visit)
rewrite(n.Args, visit)
case *ast.StarExpr:
rewrite(&n.X, visit)
case *ast.UnaryExpr:
rewrite(&n.X, visit)
case *ast.BinaryExpr:
rewrite(&n.X, visit)
rewrite(&n.Y, visit)
case *ast.KeyValueExpr:
rewrite(&n.Key, visit)
rewrite(&n.Value, visit)
case *ast.ArrayType:
rewrite(&n.Len, visit)
rewrite(&n.Elt, visit)
case *ast.StructType:
rewrite(n.Fields, visit)
case *ast.FuncType:
rewrite(n.Params, visit)
if n.Results != nil {
rewrite(n.Results, visit)
}
case *ast.InterfaceType:
rewrite(n.Methods, visit)
case *ast.MapType:
rewrite(&n.Key, visit)
rewrite(&n.Value, visit)
case *ast.ChanType:
rewrite(&n.Value, visit)
case *ast.BadStmt:
case *ast.DeclStmt:
rewrite(n.Decl, visit)
case *ast.EmptyStmt:
case *ast.LabeledStmt:
rewrite(n.Stmt, visit)
case *ast.ExprStmt:
rewrite(&n.X, visit)
case *ast.SendStmt:
rewrite(&n.Chan, visit)
rewrite(&n.Value, visit)
case *ast.IncDecStmt:
rewrite(&n.X, visit)
case *ast.AssignStmt:
rewrite(n.Lhs, visit)
if len(n.Lhs) == 2 && len(n.Rhs) == 1 {
rewrite(n.Rhs, visit)
} else {
rewrite(n.Rhs, visit)
}
case *ast.GoStmt:
rewrite(n.Call, visit)
case *ast.DeferStmt:
rewrite(n.Call, visit)
case *ast.ReturnStmt:
rewrite(n.Results, visit)
case *ast.BranchStmt:
case *ast.BlockStmt:
rewrite(n.List, visit)
case *ast.IfStmt:
rewrite(n.Init, visit)
rewrite(&n.Cond, visit)
rewrite(n.Body, visit)
rewrite(n.Else, visit)
case *ast.CaseClause:
rewrite(n.Values, visit)
rewrite(n.Body, visit)
case *ast.SwitchStmt:
rewrite(n.Init, visit)
rewrite(&n.Tag, visit)
rewrite(n.Body, visit)
case *ast.TypeCaseClause:
rewrite(n.Types, visit)
rewrite(n.Body, visit)
case *ast.TypeSwitchStmt:
rewrite(n.Init, visit)
rewrite(n.Assign, visit)
rewrite(n.Body, visit)
case *ast.CommClause:
rewrite(n.Comm, visit)
rewrite(n.Body, visit)
case *ast.SelectStmt:
rewrite(n.Body, visit)
case *ast.ForStmt:
rewrite(n.Init, visit)
rewrite(&n.Cond, visit)
rewrite(n.Post, visit)
rewrite(n.Body, visit)
case *ast.RangeStmt:
rewrite(&n.Key, visit)
rewrite(&n.Value, visit)
rewrite(&n.X, visit)
rewrite(n.Body, visit)
case *ast.ImportSpec:
case *ast.ValueSpec:
rewrite(&n.Type, visit)
rewrite(n.Values, visit)
case *ast.TypeSpec:
rewrite(&n.Type, visit)
case *ast.BadDecl:
case *ast.GenDecl:
rewrite(n.Specs, visit)
case *ast.FuncDecl:
if n.Recv != nil {
rewrite(n.Recv, visit)
}
rewrite(n.Type, visit)
if n.Body != nil {
rewrite(n.Body, visit)
}
case *ast.File:
rewrite(n.Decls, visit)
case *ast.Package:
for _, file := range n.Files {
rewrite(file, visit)
}
case []ast.Decl:
for _, d := range n {
rewrite(d, visit)
}
case []ast.Expr:
for i := range n {
rewrite(&n[i], visit)
}
case []ast.Stmt:
for _, s := range n {
rewrite(s, visit)
}
case []ast.Spec:
for _, s := range n {
rewrite(s, visit)
}
}
visit(x)
}
func imports(f *ast.File, path string) bool {
for _, decl := range f.Decls {
d, ok := decl.(*ast.GenDecl)
if !ok {
continue
}
for _, spec := range d.Specs {
s, ok := spec.(*ast.ImportSpec)
if !ok {
continue
}
if string(s.Path.Value) == `"`+path+`"` {
return true
}
}
}
return false
}
func isPkgDot(t ast.Expr, pkg, name string) bool {
sel, ok := t.(*ast.SelectorExpr)
if !ok {
return false
}
return isName(sel.X, pkg) && sel.Sel.String() == name
}
func isPtrPkgDot(t ast.Expr, pkg, name string) bool {
ptr, ok := t.(*ast.StarExpr)
if !ok {
return false
}
return isPkgDot(ptr.X, pkg, name)
}
func isName(n ast.Expr, name string) bool {
id, ok := n.(*ast.Ident)
if !ok {
return false
}
return id.String() == name
}
func refersTo(n ast.Node, x *ast.Ident) bool {
id, ok := n.(*ast.Ident)
if !ok {
return false
}
return id.String() == x.String()
}
func isBlank(n ast.Expr) bool {
return isName(n, "_")
}
func isEmptyString(n ast.Expr) bool {
lit, ok := n.(*ast.BasicLit)
if !ok {
return false
}
if lit.Kind != token.STRING {
return false
}
s := string(lit.Value)
return s == `""` || s == "``"
}
func warn(pos token.Pos, msg string, args ...interface{}) {
s := ""
if pos.IsValid() {
s = fmt.Sprintf("%s: ", fset.Position(pos).String())
}
fmt.Fprintf(os.Stderr, "%s"+msg+"\n", append([]interface{}{s}, args...))
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"go/ast"
"go/token"
)
var httpserverFix = fix{
"httpserver",
httpserver,
`Adapt http server methods and functions to changes
made to the http ResponseWriter interface.
http://codereview.appspot.com/4245064 Hijacker
http://codereview.appspot.com/4239076 Header
http://codereview.appspot.com/4239077 Flusher
http://codereview.appspot.com/4248075 RemoteAddr, UsingTLS
`,
}
func init() {
register(httpserverFix)
}
func httpserver(f *ast.File) bool {
if !imports(f, "http") {
return false
}
fixed := false
for _, decl := range f.Decls {
fn, ok := decl.(*ast.FuncDecl)
if !ok {
continue
}
w, req, ok := isServeHTTP(fn)
if !ok {
continue
}
rewrite(fn.Body, func(n interface{}) {
// Want to replace expression sometimes,
// so record pointer to it for updating below.
ptr, ok := n.(*ast.Expr)
if ok {
n = *ptr
}
// Look for w.UsingTLS() and w.Remoteaddr().
call, ok := n.(*ast.CallExpr)
if !ok || len(call.Args) != 0 {
return
}
sel, ok := call.Fun.(*ast.SelectorExpr)
if !ok {
return
}
if !refersTo(sel.X, w) {
return
}
switch sel.Sel.String() {
case "Hijack":
// replace w with w.(http.Hijacker)
sel.X = &ast.TypeAssertExpr{
X: sel.X,
Type: ast.NewIdent("http.Hijacker"),
}
fixed = true
case "Flush":
// replace w with w.(http.Flusher)
sel.X = &ast.TypeAssertExpr{
X: sel.X,
Type: ast.NewIdent("http.Flusher"),
}
fixed = true
case "UsingTLS":
if ptr == nil {
// can only replace expression if we have pointer to it
break
}
// replace with req.TLS != nil
*ptr = &ast.BinaryExpr{
X: &ast.SelectorExpr{
X: ast.NewIdent(req.String()),
Sel: ast.NewIdent("TLS"),
},
Op: token.NEQ,
Y: ast.NewIdent("nil"),
}
fixed = true
case "RemoteAddr":
if ptr == nil {
// can only replace expression if we have pointer to it
break
}
// replace with req.RemoteAddr
*ptr = &ast.SelectorExpr{
X: ast.NewIdent(req.String()),
Sel: ast.NewIdent("RemoteAddr"),
}
fixed = true
}
})
}
return fixed
}
func isServeHTTP(fn *ast.FuncDecl) (w, req *ast.Ident, ok bool) {
for _, field := range fn.Type.Params.List {
if isPkgDot(field.Type, "http", "ResponseWriter") {
w = field.Names[0]
continue
}
if isPtrPkgDot(field.Type, "http", "Request") {
req = field.Names[0]
continue
}
}
ok = w != nil && req != nil
return
}
package main
func init() {
addTestCases(httpserverTests)
}
var httpserverTests = []testCase{
{
Name: "httpserver.0",
Fn: httpserver,
In: `package main
import "http"
func f(xyz http.ResponseWriter, abc *http.Request, b string) {
xyz.Hijack()
xyz.Flush()
go xyz.Hijack()
defer xyz.Flush()
_ = xyz.UsingTLS()
_ = true == xyz.UsingTLS()
_ = xyz.RemoteAddr()
_ = xyz.RemoteAddr() == "hello"
if xyz.UsingTLS() {
}
}
`,
Out: `package main
import "http"
func f(xyz http.ResponseWriter, abc *http.Request, b string) {
xyz.(http.Hijacker).Hijack()
xyz.(http.Flusher).Flush()
go xyz.(http.Hijacker).Hijack()
defer xyz.(http.Flusher).Flush()
_ = abc.TLS != nil
_ = true == (abc.TLS != nil)
_ = abc.RemoteAddr
_ = abc.RemoteAddr == "hello"
if abc.TLS != nil {
}
}
`,
},
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"bytes"
"flag"
"fmt"
"go/parser"
"go/printer"
"go/scanner"
"go/token"
"io/ioutil"
"os"
"path/filepath"
"sort"
"strings"
)
var (
fset = token.NewFileSet()
exitCode = 0
)
var allowedRewrites = flag.String("r", "",
"restrict the rewrites to this comma-separated list")
var allowed map[string]bool
func usage() {
fmt.Fprintf(os.Stderr, "usage: gofix [-r fixname,...] [path ...]\n")
flag.PrintDefaults()
fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n")
for _, f := range fixes {
fmt.Fprintf(os.Stderr, "\n%s\n", f.name)
desc := strings.TrimSpace(f.desc)
desc = strings.Replace(desc, "\n", "\n\t", -1)
fmt.Fprintf(os.Stderr, "\t%s\n", desc)
}
os.Exit(2)
}
func main() {
sort.Sort(fixes)
flag.Usage = usage
flag.Parse()
if *allowedRewrites != "" {
allowed = make(map[string]bool)
for _, f := range strings.Split(*allowedRewrites, ",", -1) {
allowed[f] = true
}
}
if flag.NArg() == 0 {
if err := processFile("standard input", true); err != nil {
report(err)
}
os.Exit(exitCode)
}
for i := 0; i < flag.NArg(); i++ {
path := flag.Arg(i)
switch dir, err := os.Stat(path); {
case err != nil:
report(err)
case dir.IsRegular():
if err := processFile(path, false); err != nil {
report(err)
}
case dir.IsDirectory():
walkDir(path)
}
}
os.Exit(exitCode)
}
const (
tabWidth = 8
parserMode = parser.ParseComments
printerMode = printer.TabIndent
)
func processFile(filename string, useStdin bool) os.Error {
var f *os.File
var err os.Error
if useStdin {
f = os.Stdin
} else {
f, err = os.Open(filename, os.O_RDONLY, 0)
if err != nil {
return err
}
defer f.Close()
}
src, err := ioutil.ReadAll(f)
if err != nil {
return err
}
file, err := parser.ParseFile(fset, filename, src, parserMode)
if err != nil {
return err
}
fixed := false
var buf bytes.Buffer
for _, fix := range fixes {
if allowed != nil && !allowed[fix.desc] {
continue
}
if fix.f(file) {
fixed = true
fmt.Fprintf(&buf, " %s", fix.name)
}
}
if !fixed {
return nil
}
fmt.Fprintf(os.Stderr, "%s: %s\n", filename, buf.String()[1:])
buf.Reset()
_, err = (&printer.Config{printerMode, tabWidth}).Fprint(&buf, fset, file)
if err != nil {
return err
}
if useStdin {
os.Stdout.Write(buf.Bytes())
return nil
}
return ioutil.WriteFile(f.Name(), buf.Bytes(), 0)
}
func report(err os.Error) {
scanner.PrintError(os.Stderr, err)
exitCode = 2
}
func walkDir(path string) {
v := make(fileVisitor)
go func() {
filepath.Walk(path, v, v)
close(v)
}()
for err := range v {
if err != nil {
report(err)
}
}
}
type fileVisitor chan os.Error
func (v fileVisitor) VisitDir(path string, f *os.FileInfo) bool {
return true
}
func (v fileVisitor) VisitFile(path string, f *os.FileInfo) {
if isGoFile(f) {
v <- nil // synchronize error handler
if err := processFile(path, false); err != nil {
v <- err
}
}
}
func isGoFile(f *os.FileInfo) bool {
// ignore non-Go files
return f.IsRegular() && !strings.HasPrefix(f.Name, ".") && strings.HasSuffix(f.Name, ".go")
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"bytes"
"go/ast"
"go/parser"
"go/printer"
"testing"
)
type testCase struct {
Name string
Fn func(*ast.File) bool
In string
Out string
}
var testCases []testCase
func addTestCases(t []testCase) {
testCases = append(testCases, t...)
}
func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string) (out string, fixed, ok bool) {
file, err := parser.ParseFile(fset, desc, in, parserMode)
if err != nil {
t.Errorf("%s: parsing: %v", desc, err)
return
}
var buf bytes.Buffer
buf.Reset()
_, err = (&printer.Config{printerMode, tabWidth}).Fprint(&buf, fset, file)
if err != nil {
t.Errorf("%s: printing: %v", desc, err)
return
}
if s := buf.String(); in != s {
t.Errorf("%s: not gofmt-formatted.\n--- %s\n%s\n--- %s | gofmt\n%s",
desc, desc, in, desc, s)
return
}
if fn == nil {
for _, fix := range fixes {
if fix.f(file) {
fixed = true
}
}
} else {
fixed = fn(file)
}
buf.Reset()
_, err = (&printer.Config{printerMode, tabWidth}).Fprint(&buf, fset, file)
if err != nil {
t.Errorf("%s: printing: %v", desc, err)
return
}
return buf.String(), fixed, true
}
func TestRewrite(t *testing.T) {
for _, tt := range testCases {
// Apply fix: should get tt.Out.
out, fixed, ok := parseFixPrint(t, tt.Fn, tt.Name, tt.In)
if !ok {
continue
}
if out != tt.Out {
t.Errorf("%s: incorrect output.\n--- have\n%s\n--- want\n%s", tt.Name, out, tt.Out)
continue
}
if changed := out != tt.In; changed != fixed {
t.Errorf("%s: changed=%v != fixed=%v", tt.Name, changed, fixed)
continue
}
// Should not change if run again.
out2, fixed2, ok := parseFixPrint(t, tt.Fn, tt.Name+" output", out)
if !ok {
continue
}
if fixed2 {
t.Errorf("%s: applied fixes during second round", tt.Name)
continue
}
if out2 != out {
t.Errorf("%s: changed output after second round of fixes.\n--- output after first round\n%s\n--- output after second round\n%s",
tt.Name, out, out2)
}
}
}
......@@ -155,6 +155,7 @@ DIRS=\
../cmd/cgo\
../cmd/ebnflint\
../cmd/godoc\
../cmd/gofix\
../cmd/gofmt\
../cmd/gotype\
../cmd/goinstall\
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment