Commit 3d4b55ad authored by Robert Griesemer's avatar Robert Griesemer

gofmt: minor refactor to permit easy testing

R=rsc
CC=golang-dev
https://golang.org/cl/4397046
parent 7b6ee1a5
...@@ -13,6 +13,7 @@ import ( ...@@ -13,6 +13,7 @@ import (
"go/printer" "go/printer"
"go/scanner" "go/scanner"
"go/token" "go/token"
"io"
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
...@@ -86,14 +87,23 @@ func isGoFile(f *os.FileInfo) bool { ...@@ -86,14 +87,23 @@ func isGoFile(f *os.FileInfo) bool {
} }
func processFile(f *os.File) os.Error { // If in == nil, the source is the contents of the file with the given filename.
src, err := ioutil.ReadAll(f) func processFile(filename string, in io.Reader, out io.Writer) os.Error {
if in == nil {
f, err := os.Open(filename)
if err != nil {
return err
}
defer f.Close()
in = f
}
src, err := ioutil.ReadAll(in)
if err != nil { if err != nil {
return err return err
} }
file, err := parser.ParseFile(fset, f.Name(), src, parserMode) file, err := parser.ParseFile(fset, filename, src, parserMode)
if err != nil { if err != nil {
return err return err
} }
...@@ -116,10 +126,10 @@ func processFile(f *os.File) os.Error { ...@@ -116,10 +126,10 @@ func processFile(f *os.File) os.Error {
if !bytes.Equal(src, res) { if !bytes.Equal(src, res) {
// formatting has changed // formatting has changed
if *list { if *list {
fmt.Fprintln(os.Stdout, f.Name()) fmt.Fprintln(out, filename)
} }
if *write { if *write {
err = ioutil.WriteFile(f.Name(), res, 0) err = ioutil.WriteFile(filename, res, 0)
if err != nil { if err != nil {
return err return err
} }
...@@ -127,23 +137,13 @@ func processFile(f *os.File) os.Error { ...@@ -127,23 +137,13 @@ func processFile(f *os.File) os.Error {
} }
if !*list && !*write { if !*list && !*write {
_, err = os.Stdout.Write(res) _, err = out.Write(res)
} }
return err return err
} }
func processFileByName(filename string) os.Error {
file, err := os.Open(filename)
if err != nil {
return err
}
defer file.Close()
return processFile(file)
}
type fileVisitor chan os.Error type fileVisitor chan os.Error
func (v fileVisitor) VisitDir(path string, f *os.FileInfo) bool { func (v fileVisitor) VisitDir(path string, f *os.FileInfo) bool {
...@@ -154,7 +154,7 @@ func (v fileVisitor) VisitDir(path string, f *os.FileInfo) bool { ...@@ -154,7 +154,7 @@ func (v fileVisitor) VisitDir(path string, f *os.FileInfo) bool {
func (v fileVisitor) VisitFile(path string, f *os.FileInfo) { func (v fileVisitor) VisitFile(path string, f *os.FileInfo) {
if isGoFile(f) { if isGoFile(f) {
v <- nil // synchronize error handler v <- nil // synchronize error handler
if err := processFileByName(path); err != nil { if err := processFile(path, nil, os.Stdout); err != nil {
v <- err v <- err
} }
} }
...@@ -210,9 +210,10 @@ func gofmtMain() { ...@@ -210,9 +210,10 @@ func gofmtMain() {
initRewrite() initRewrite()
if flag.NArg() == 0 { if flag.NArg() == 0 {
if err := processFile(os.Stdin); err != nil { if err := processFile("<standard input>", os.Stdin, os.Stdout); err != nil {
report(err) report(err)
} }
return
} }
for i := 0; i < flag.NArg(); i++ { for i := 0; i < flag.NArg(); i++ {
...@@ -221,7 +222,7 @@ func gofmtMain() { ...@@ -221,7 +222,7 @@ func gofmtMain() {
case err != nil: case err != nil:
report(err) report(err)
case dir.IsRegular(): case dir.IsRegular():
if err := processFileByName(path); err != nil { if err := processFile(path, nil, os.Stdout); err != nil {
report(err) report(err)
} }
case dir.IsDirectory(): case dir.IsDirectory():
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment