Commit 6323a40f authored by Russ Cox's avatar Russ Cox

gofix: test import insertion, deletion

Small change to go/ast, go/parser, go/printer so that
gofix can delete the blank line left from deleting an import.

R=golang-dev, bradfitz, adg
CC=golang-dev
https://golang.org/cl/5321046
parent 8fee9bc8
...@@ -19,7 +19,7 @@ type fix struct { ...@@ -19,7 +19,7 @@ type fix struct {
desc string desc string
} }
// main runs sort.Sort(fixes) after init process is done. // main runs sort.Sort(fixes) before printing list of fixes.
type fixlist []fix type fixlist []fix
func (f fixlist) Len() int { return len(f) } func (f fixlist) Len() int { return len(f) }
...@@ -316,6 +316,20 @@ func importPath(s *ast.ImportSpec) string { ...@@ -316,6 +316,20 @@ func importPath(s *ast.ImportSpec) string {
return "" return ""
} }
// declImports reports whether gen contains an import of path.
func declImports(gen *ast.GenDecl, path string) bool {
if gen.Tok != token.IMPORT {
return false
}
for _, spec := range gen.Specs {
impspec := spec.(*ast.ImportSpec)
if importPath(impspec) == path {
return true
}
}
return false
}
// isPkgDot returns true if t is the expression "pkg.name" // isPkgDot returns true if t is the expression "pkg.name"
// where pkg is an imported identifier. // where pkg is an imported identifier.
func isPkgDot(t ast.Expr, pkg, name string) bool { func isPkgDot(t ast.Expr, pkg, name string) bool {
...@@ -486,14 +500,20 @@ func addImport(f *ast.File, path string) { ...@@ -486,14 +500,20 @@ func addImport(f *ast.File, path string) {
var impdecl *ast.GenDecl var impdecl *ast.GenDecl
// Find an import decl to add to. // Find an import decl to add to.
for _, decl := range f.Decls { var lastImport int = -1
for i, decl := range f.Decls {
gen, ok := decl.(*ast.GenDecl) gen, ok := decl.(*ast.GenDecl)
if ok && gen.Tok == token.IMPORT { if ok && gen.Tok == token.IMPORT {
lastImport = i
// Do not add to import "C", to avoid disrupting the
// association with its doc comment, breaking cgo.
if !declImports(gen, "C") {
impdecl = gen impdecl = gen
break break
} }
} }
}
// No import decl found. Add one. // No import decl found. Add one.
if impdecl == nil { if impdecl == nil {
...@@ -501,8 +521,8 @@ func addImport(f *ast.File, path string) { ...@@ -501,8 +521,8 @@ func addImport(f *ast.File, path string) {
Tok: token.IMPORT, Tok: token.IMPORT,
} }
f.Decls = append(f.Decls, nil) f.Decls = append(f.Decls, nil)
copy(f.Decls[1:], f.Decls) copy(f.Decls[lastImport+2:], f.Decls[lastImport+1:])
f.Decls[0] = impdecl f.Decls[lastImport+1] = impdecl
} }
// Ensure the import decl has parentheses, if needed. // Ensure the import decl has parentheses, if needed.
...@@ -540,7 +560,6 @@ func deleteImport(f *ast.File, path string) { ...@@ -540,7 +560,6 @@ func deleteImport(f *ast.File, path string) {
} }
for j, spec := range gen.Specs { for j, spec := range gen.Specs {
impspec := spec.(*ast.ImportSpec) impspec := spec.(*ast.ImportSpec)
if oldImport != impspec { if oldImport != impspec {
continue continue
} }
...@@ -558,7 +577,13 @@ func deleteImport(f *ast.File, path string) { ...@@ -558,7 +577,13 @@ func deleteImport(f *ast.File, path string) {
} else if len(gen.Specs) == 1 { } else if len(gen.Specs) == 1 {
gen.Lparen = token.NoPos // drop parens gen.Lparen = token.NoPos // drop parens
} }
if j > 0 {
// We deleted an entry but now there will be
// a blank line-sized hole where the import was.
// Close the hole by making the previous
// import appear to "end" where this one did.
gen.Specs[j-1].(*ast.ImportSpec).EndPos = impspec.End()
}
break break
} }
} }
......
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import "go/ast"
func init() {
addTestCases(importTests, nil)
}
var importTests = []testCase{
{
Name: "import.0",
Fn: addImportFn("os"),
In: `package main
import (
"os"
)
`,
Out: `package main
import (
"os"
)
`,
},
{
Name: "import.1",
Fn: addImportFn("os"),
In: `package main
`,
Out: `package main
import "os"
`,
},
{
Name: "import.2",
Fn: addImportFn("os"),
In: `package main
// Comment
import "C"
`,
Out: `package main
// Comment
import "C"
import "os"
`,
},
{
Name: "import.3",
Fn: addImportFn("os"),
In: `package main
// Comment
import "C"
import (
"io"
"utf8"
)
`,
Out: `package main
// Comment
import "C"
import (
"io"
"os"
"utf8"
)
`,
},
{
Name: "import.4",
Fn: deleteImportFn("os"),
In: `package main
import (
"os"
)
`,
Out: `package main
`,
},
{
Name: "import.5",
Fn: deleteImportFn("os"),
In: `package main
// Comment
import "C"
import "os"
`,
Out: `package main
// Comment
import "C"
`,
},
{
Name: "import.6",
Fn: deleteImportFn("os"),
In: `package main
// Comment
import "C"
import (
"io"
"os"
"utf8"
)
`,
Out: `package main
// Comment
import "C"
import (
"io"
"utf8"
)
`,
},
{
Name: "import.7",
Fn: deleteImportFn("io"),
In: `package main
import (
"io" // a
"os" // b
"utf8" // c
)
`,
Out: `package main
import (
// a
"os" // b
"utf8" // c
)
`,
},
{
Name: "import.8",
Fn: deleteImportFn("os"),
In: `package main
import (
"io" // a
"os" // b
"utf8" // c
)
`,
Out: `package main
import (
"io" // a
// b
"utf8" // c
)
`,
},
{
Name: "import.9",
Fn: deleteImportFn("utf8"),
In: `package main
import (
"io" // a
"os" // b
"utf8" // c
)
`,
Out: `package main
import (
"io" // a
"os" // b
// c
)
`,
},
{
Name: "import.10",
Fn: deleteImportFn("io"),
In: `package main
import (
"io"
"os"
"utf8"
)
`,
Out: `package main
import (
"os"
"utf8"
)
`,
},
{
Name: "import.11",
Fn: deleteImportFn("os"),
In: `package main
import (
"io"
"os"
"utf8"
)
`,
Out: `package main
import (
"io"
"utf8"
)
`,
},
{
Name: "import.12",
Fn: deleteImportFn("utf8"),
In: `package main
import (
"io"
"os"
"utf8"
)
`,
Out: `package main
import (
"io"
"os"
)
`,
},
}
func addImportFn(path string) func(*ast.File) bool {
return func(f *ast.File) bool {
if !imports(f, path) {
addImport(f, path)
return true
}
return false
}
}
func deleteImportFn(path string) func(*ast.File) bool {
return func(f *ast.File) bool {
if imports(f, path) {
deleteImport(f, path)
return true
}
return false
}
}
...@@ -752,6 +752,7 @@ type ( ...@@ -752,6 +752,7 @@ type (
Name *Ident // local package name (including "."); or nil Name *Ident // local package name (including "."); or nil
Path *BasicLit // import path Path *BasicLit // import path
Comment *CommentGroup // line comments; or nil Comment *CommentGroup // line comments; or nil
EndPos token.Pos // end of spec (overrides Path.Pos if nonzero)
} }
// A ValueSpec node represents a constant or variable declaration // A ValueSpec node represents a constant or variable declaration
...@@ -785,7 +786,13 @@ func (s *ImportSpec) Pos() token.Pos { ...@@ -785,7 +786,13 @@ func (s *ImportSpec) Pos() token.Pos {
func (s *ValueSpec) Pos() token.Pos { return s.Names[0].Pos() } func (s *ValueSpec) Pos() token.Pos { return s.Names[0].Pos() }
func (s *TypeSpec) Pos() token.Pos { return s.Name.Pos() } func (s *TypeSpec) Pos() token.Pos { return s.Name.Pos() }
func (s *ImportSpec) End() token.Pos { return s.Path.End() } func (s *ImportSpec) End() token.Pos {
if s.EndPos != 0 {
return s.EndPos
}
return s.Path.End()
}
func (s *ValueSpec) End() token.Pos { func (s *ValueSpec) End() token.Pos {
if n := len(s.Values); n > 0 { if n := len(s.Values); n > 0 {
return s.Values[n-1].End() return s.Values[n-1].End()
......
...@@ -1909,7 +1909,7 @@ func parseImportSpec(p *parser, doc *ast.CommentGroup, _ int) ast.Spec { ...@@ -1909,7 +1909,7 @@ func parseImportSpec(p *parser, doc *ast.CommentGroup, _ int) ast.Spec {
p.expectSemi() // call before accessing p.linecomment p.expectSemi() // call before accessing p.linecomment
// collect imports // collect imports
spec := &ast.ImportSpec{doc, ident, path, p.lineComment} spec := &ast.ImportSpec{doc, ident, path, p.lineComment, token.NoPos}
p.imports = append(p.imports, spec) p.imports = append(p.imports, spec)
return spec return spec
......
...@@ -1278,6 +1278,7 @@ func (p *printer) spec(spec ast.Spec, n int, doIndent bool, multiLine *bool) { ...@@ -1278,6 +1278,7 @@ func (p *printer) spec(spec ast.Spec, n int, doIndent bool, multiLine *bool) {
} }
p.expr(s.Path, multiLine) p.expr(s.Path, multiLine)
p.setComment(s.Comment) p.setComment(s.Comment)
p.print(s.EndPos)
case *ast.ValueSpec: case *ast.ValueSpec:
if n != 1 { if n != 1 {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment