Commit 558e7fc3 authored by Russ Cox's avatar Russ Cox

various: avoid func compare

R=gri, r, bradfitz
CC=golang-dev
https://golang.org/cl/5371074
parent c017a829
......@@ -34,7 +34,7 @@ func addTestCases(t []testCase, fn func(*ast.File) bool) {
func fnop(*ast.File) bool { return false }
func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string) (out string, fixed, ok bool) {
func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string, mustBeGofmt bool) (out string, fixed, ok bool) {
file, err := parser.ParseFile(fset, desc, in, parserMode)
if err != nil {
t.Errorf("%s: parsing: %v", desc, err)
......@@ -46,7 +46,7 @@ func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string) (out
t.Errorf("%s: printing: %v", desc, err)
return
}
if s := string(outb); in != s && fn != fnop {
if s := string(outb); in != s && mustBeGofmt {
t.Errorf("%s: not gofmt-formatted.\n--- %s\n%s\n--- %s | gofmt\n%s",
desc, desc, in, desc, s)
tdiff(t, in, s)
......@@ -75,13 +75,13 @@ func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string) (out
func TestRewrite(t *testing.T) {
for _, tt := range testCases {
// Apply fix: should get tt.Out.
out, fixed, ok := parseFixPrint(t, tt.Fn, tt.Name, tt.In)
out, fixed, ok := parseFixPrint(t, tt.Fn, tt.Name, tt.In, true)
if !ok {
continue
}
// reformat to get printing right
out, _, ok = parseFixPrint(t, fnop, tt.Name, out)
out, _, ok = parseFixPrint(t, fnop, tt.Name, out, false)
if !ok {
continue
}
......@@ -101,7 +101,7 @@ func TestRewrite(t *testing.T) {
}
// Should not change if run again.
out2, fixed2, ok := parseFixPrint(t, tt.Fn, tt.Name+" output", out)
out2, fixed2, ok := parseFixPrint(t, tt.Fn, tt.Name+" output", out, true)
if !ok {
continue
}
......
......@@ -138,6 +138,7 @@ func typecheck(cfg *TypeConfig, f *ast.File) (typeof map[interface{}]string, ass
assign = make(map[string][]interface{})
cfg1 := &TypeConfig{}
*cfg1 = *cfg // make copy so we can add locally
copied := false
// gather function declarations
for _, decl := range f.Decls {
......@@ -185,7 +186,8 @@ func typecheck(cfg *TypeConfig, f *ast.File) (typeof map[interface{}]string, ass
if cfg1.Type[s.Name.Name] != nil {
break
}
if cfg1.Type == cfg.Type || cfg1.Type == nil {
if !copied {
copied = true
// Copy map lazily: it's time.
cfg1.Type = make(map[string]*Type)
for k, v := range cfg.Type {
......
......@@ -662,48 +662,49 @@ func TestRunes(t *testing.T) {
}
type TrimTest struct {
f func([]byte, string) []byte
f string
in, cutset, out string
}
var trimTests = []TrimTest{
{Trim, "abba", "a", "bb"},
{Trim, "abba", "ab", ""},
{TrimLeft, "abba", "ab", ""},
{TrimRight, "abba", "ab", ""},
{TrimLeft, "abba", "a", "bba"},
{TrimRight, "abba", "a", "abb"},
{Trim, "<tag>", "<>", "tag"},
{Trim, "* listitem", " *", "listitem"},
{Trim, `"quote"`, `"`, "quote"},
{Trim, "\u2C6F\u2C6F\u0250\u0250\u2C6F\u2C6F", "\u2C6F", "\u0250\u0250"},
{"Trim", "abba", "a", "bb"},
{"Trim", "abba", "ab", ""},
{"TrimLeft", "abba", "ab", ""},
{"TrimRight", "abba", "ab", ""},
{"TrimLeft", "abba", "a", "bba"},
{"TrimRight", "abba", "a", "abb"},
{"Trim", "<tag>", "<>", "tag"},
{"Trim", "* listitem", " *", "listitem"},
{"Trim", `"quote"`, `"`, "quote"},
{"Trim", "\u2C6F\u2C6F\u0250\u0250\u2C6F\u2C6F", "\u2C6F", "\u0250\u0250"},
//empty string tests
{Trim, "abba", "", "abba"},
{Trim, "", "123", ""},
{Trim, "", "", ""},
{TrimLeft, "abba", "", "abba"},
{TrimLeft, "", "123", ""},
{TrimLeft, "", "", ""},
{TrimRight, "abba", "", "abba"},
{TrimRight, "", "123", ""},
{TrimRight, "", "", ""},
{TrimRight, "☺\xc0", "☺", "☺\xc0"},
{"Trim", "abba", "", "abba"},
{"Trim", "", "123", ""},
{"Trim", "", "", ""},
{"TrimLeft", "abba", "", "abba"},
{"TrimLeft", "", "123", ""},
{"TrimLeft", "", "", ""},
{"TrimRight", "abba", "", "abba"},
{"TrimRight", "", "123", ""},
{"TrimRight", "", "", ""},
{"TrimRight", "☺\xc0", "☺", "☺\xc0"},
}
func TestTrim(t *testing.T) {
for _, tc := range trimTests {
actual := string(tc.f([]byte(tc.in), tc.cutset))
var name string
switch tc.f {
case Trim:
name = "Trim"
case TrimLeft:
name = "TrimLeft"
case TrimRight:
name = "TrimRight"
name := tc.f
var f func([]byte, string) []byte
switch name {
case "Trim":
f = Trim
case "TrimLeft":
f = TrimLeft
case "TrimRight":
f = TrimRight
default:
t.Error("Undefined trim function")
t.Error("Undefined trim function %s", name)
}
actual := string(f([]byte(tc.in), tc.cutset))
if actual != tc.out {
t.Errorf("%s(%q, %q) = %q; want %q", name, tc.in, tc.cutset, actual, tc.out)
}
......
......@@ -48,8 +48,9 @@ const (
type encoder struct {
// w is the writer that compressed bytes are written to.
w writer
// write, bits, nBits and width are the state for converting a code stream
// into a byte stream.
// order, write, bits, nBits and width are the state for
// converting a code stream into a byte stream.
order Order
write func(*encoder, uint32) error
bits uint32
nBits uint
......@@ -213,7 +214,7 @@ func (e *encoder) Close() error {
}
// Write the final bits.
if e.nBits > 0 {
if e.write == (*encoder).writeMSB {
if e.order == MSB {
e.bits >>= 24
}
if err := e.w.WriteByte(uint8(e.bits)); err != nil {
......@@ -249,6 +250,7 @@ func NewWriter(w io.Writer, order Order, litWidth int) io.WriteCloser {
lw := uint(litWidth)
return &encoder{
w: bw,
order: order,
write: write,
width: 1 + lw,
litWidth: lw,
......
......@@ -227,7 +227,7 @@ func (d *decodeState) value(v reflect.Value) {
// d.scan thinks we're still at the beginning of the item.
// Feed in an empty string - the shortest, simplest value -
// so that it knows we got to the end of the value.
if d.scan.step == stateRedo {
if d.scan.redo {
panic("redo")
}
d.scan.step(&d.scan, '"')
......
......@@ -80,6 +80,9 @@ type scanner struct {
// on a 64-bit Mac Mini, and it's nicer to read.
step func(*scanner, int) int
// Reached end of top-level value.
endTop bool
// Stack of what we're in the middle of - array values, object keys, object values.
parseState []int
......@@ -87,6 +90,7 @@ type scanner struct {
err error
// 1-byte redo (see undo method)
redo bool
redoCode int
redoState func(*scanner, int) int
......@@ -135,6 +139,8 @@ func (s *scanner) reset() {
s.step = stateBeginValue
s.parseState = s.parseState[0:0]
s.err = nil
s.redo = false
s.endTop = false
}
// eof tells the scanner that the end of input has been reached.
......@@ -143,11 +149,11 @@ func (s *scanner) eof() int {
if s.err != nil {
return scanError
}
if s.step == stateEndTop {
if s.endTop {
return scanEnd
}
s.step(s, ' ')
if s.step == stateEndTop {
if s.endTop {
return scanEnd
}
if s.err == nil {
......@@ -166,8 +172,10 @@ func (s *scanner) pushParseState(p int) {
func (s *scanner) popParseState() {
n := len(s.parseState) - 1
s.parseState = s.parseState[0:n]
s.redo = false
if n == 0 {
s.step = stateEndTop
s.endTop = true
} else {
s.step = stateEndValue
}
......@@ -269,6 +277,7 @@ func stateEndValue(s *scanner, c int) int {
if n == 0 {
// Completed top-level before the current byte.
s.step = stateEndTop
s.endTop = true
return stateEndTop(s, c)
}
if c <= ' ' && (c == ' ' || c == '\t' || c == '\r' || c == '\n') {
......@@ -606,16 +615,18 @@ func quoteChar(c int) string {
// undo causes the scanner to return scanCode from the next state transition.
// This gives callers a simple 1-byte undo mechanism.
func (s *scanner) undo(scanCode int) {
if s.step == stateRedo {
panic("invalid use of scanner")
if s.redo {
panic("json: invalid use of scanner")
}
s.redoCode = scanCode
s.redoState = s.step
s.step = stateRedo
s.redo = true
}
// stateRedo helps implement the scanner's 1-byte undo.
func stateRedo(s *scanner, c int) int {
s.redo = false
s.step = s.redoState
return s.redoCode
}
......@@ -24,7 +24,7 @@ func exportFilter(name string) bool {
// it returns false otherwise.
//
func FileExports(src *File) bool {
return FilterFile(src, exportFilter)
return filterFile(src, exportFilter, true)
}
// PackageExports trims the AST for a Go package in place such that
......@@ -35,7 +35,7 @@ func FileExports(src *File) bool {
// it returns false otherwise.
//
func PackageExports(pkg *Package) bool {
return FilterPackage(pkg, exportFilter)
return filterPackage(pkg, exportFilter, true)
}
// ----------------------------------------------------------------------------
......@@ -72,7 +72,7 @@ func fieldName(x Expr) *Ident {
return nil
}
func filterFieldList(fields *FieldList, filter Filter) (removedFields bool) {
func filterFieldList(fields *FieldList, filter Filter, export bool) (removedFields bool) {
if fields == nil {
return false
}
......@@ -93,8 +93,8 @@ func filterFieldList(fields *FieldList, filter Filter) (removedFields bool) {
keepField = len(f.Names) > 0
}
if keepField {
if filter == exportFilter {
filterType(f.Type, filter)
if export {
filterType(f.Type, filter, export)
}
list[j] = f
j++
......@@ -107,84 +107,84 @@ func filterFieldList(fields *FieldList, filter Filter) (removedFields bool) {
return
}
func filterParamList(fields *FieldList, filter Filter) bool {
func filterParamList(fields *FieldList, filter Filter, export bool) bool {
if fields == nil {
return false
}
var b bool
for _, f := range fields.List {
if filterType(f.Type, filter) {
if filterType(f.Type, filter, export) {
b = true
}
}
return b
}
func filterType(typ Expr, f Filter) bool {
func filterType(typ Expr, f Filter, export bool) bool {
switch t := typ.(type) {
case *Ident:
return f(t.Name)
case *ParenExpr:
return filterType(t.X, f)
return filterType(t.X, f, export)
case *ArrayType:
return filterType(t.Elt, f)
return filterType(t.Elt, f, export)
case *StructType:
if filterFieldList(t.Fields, f) {
if filterFieldList(t.Fields, f, export) {
t.Incomplete = true
}
return len(t.Fields.List) > 0
case *FuncType:
b1 := filterParamList(t.Params, f)
b2 := filterParamList(t.Results, f)
b1 := filterParamList(t.Params, f, export)
b2 := filterParamList(t.Results, f, export)
return b1 || b2
case *InterfaceType:
if filterFieldList(t.Methods, f) {
if filterFieldList(t.Methods, f, export) {
t.Incomplete = true
}
return len(t.Methods.List) > 0
case *MapType:
b1 := filterType(t.Key, f)
b2 := filterType(t.Value, f)
b1 := filterType(t.Key, f, export)
b2 := filterType(t.Value, f, export)
return b1 || b2
case *ChanType:
return filterType(t.Value, f)
return filterType(t.Value, f, export)
}
return false
}
func filterSpec(spec Spec, f Filter) bool {
func filterSpec(spec Spec, f Filter, export bool) bool {
switch s := spec.(type) {
case *ValueSpec:
s.Names = filterIdentList(s.Names, f)
if len(s.Names) > 0 {
if f == exportFilter {
filterType(s.Type, f)
if export {
filterType(s.Type, f, export)
}
return true
}
case *TypeSpec:
if f(s.Name.Name) {
if f == exportFilter {
filterType(s.Type, f)
if export {
filterType(s.Type, f, export)
}
return true
}
if f != exportFilter {
if !export {
// For general filtering (not just exports),
// filter type even if name is not filtered
// out.
// If the type contains filtered elements,
// keep the declaration.
return filterType(s.Type, f)
return filterType(s.Type, f, export)
}
}
return false
}
func filterSpecList(list []Spec, f Filter) []Spec {
func filterSpecList(list []Spec, f Filter, export bool) []Spec {
j := 0
for _, s := range list {
if filterSpec(s, f) {
if filterSpec(s, f, export) {
list[j] = s
j++
}
......@@ -200,9 +200,13 @@ func filterSpecList(list []Spec, f Filter) []Spec {
// filtering; it returns false otherwise.
//
func FilterDecl(decl Decl, f Filter) bool {
return filterDecl(decl, f, false)
}
func filterDecl(decl Decl, f Filter, export bool) bool {
switch d := decl.(type) {
case *GenDecl:
d.Specs = filterSpecList(d.Specs, f)
d.Specs = filterSpecList(d.Specs, f, export)
return len(d.Specs) > 0
case *FuncDecl:
return f(d.Name.Name)
......@@ -221,9 +225,13 @@ func FilterDecl(decl Decl, f Filter) bool {
// left after filtering; it returns false otherwise.
//
func FilterFile(src *File, f Filter) bool {
return filterFile(src, f, false)
}
func filterFile(src *File, f Filter, export bool) bool {
j := 0
for _, d := range src.Decls {
if FilterDecl(d, f) {
if filterDecl(d, f, export) {
src.Decls[j] = d
j++
}
......@@ -244,9 +252,13 @@ func FilterFile(src *File, f Filter) bool {
// left after filtering; it returns false otherwise.
//
func FilterPackage(pkg *Package, f Filter) bool {
return filterPackage(pkg, f, false)
}
func filterPackage(pkg *Package, f Filter, export bool) bool {
hasDecls := false
for _, src := range pkg.Files {
if FilterFile(src, f) {
if filterFile(src, f, export) {
hasDecls = true
}
}
......
......@@ -9,7 +9,7 @@ package net
var supportsIPv6, supportsIPv4map = probeIPv6Stack()
func firstFavoriteAddr(filter func(IP) IP, addrs []string) (addr IP) {
if filter == anyaddr {
if filter == nil {
// We'll take any IP address, but since the dialing code
// does not yet try multiple addresses, prefer to use
// an IPv4 address if possible. This is especially relevant
......@@ -113,7 +113,7 @@ func hostPortToIP(net, hostport string) (ip IP, iport int, err error) {
// Try as an IP address.
addr = ParseIP(host)
if addr == nil {
filter := anyaddr
var filter func(IP) IP
if net != "" && net[len(net)-1] == '4' {
filter = ipv4only
}
......
......@@ -489,46 +489,47 @@ func TestSpecialCase(t *testing.T) {
func TestTrimSpace(t *testing.T) { runStringTests(t, TrimSpace, "TrimSpace", trimSpaceTests) }
var trimTests = []struct {
f func(string, string) string
f string
in, cutset, out string
}{
{Trim, "abba", "a", "bb"},
{Trim, "abba", "ab", ""},
{TrimLeft, "abba", "ab", ""},
{TrimRight, "abba", "ab", ""},
{TrimLeft, "abba", "a", "bba"},
{TrimRight, "abba", "a", "abb"},
{Trim, "<tag>", "<>", "tag"},
{Trim, "* listitem", " *", "listitem"},
{Trim, `"quote"`, `"`, "quote"},
{Trim, "\u2C6F\u2C6F\u0250\u0250\u2C6F\u2C6F", "\u2C6F", "\u0250\u0250"},
{"Trim", "abba", "a", "bb"},
{"Trim", "abba", "ab", ""},
{"TrimLeft", "abba", "ab", ""},
{"TrimRight", "abba", "ab", ""},
{"TrimLeft", "abba", "a", "bba"},
{"TrimRight", "abba", "a", "abb"},
{"Trim", "<tag>", "<>", "tag"},
{"Trim", "* listitem", " *", "listitem"},
{"Trim", `"quote"`, `"`, "quote"},
{"Trim", "\u2C6F\u2C6F\u0250\u0250\u2C6F\u2C6F", "\u2C6F", "\u0250\u0250"},
//empty string tests
{Trim, "abba", "", "abba"},
{Trim, "", "123", ""},
{Trim, "", "", ""},
{TrimLeft, "abba", "", "abba"},
{TrimLeft, "", "123", ""},
{TrimLeft, "", "", ""},
{TrimRight, "abba", "", "abba"},
{TrimRight, "", "123", ""},
{TrimRight, "", "", ""},
{TrimRight, "☺\xc0", "☺", "☺\xc0"},
{"Trim", "abba", "", "abba"},
{"Trim", "", "123", ""},
{"Trim", "", "", ""},
{"TrimLeft", "abba", "", "abba"},
{"TrimLeft", "", "123", ""},
{"TrimLeft", "", "", ""},
{"TrimRight", "abba", "", "abba"},
{"TrimRight", "", "123", ""},
{"TrimRight", "", "", ""},
{"TrimRight", "☺\xc0", "☺", "☺\xc0"},
}
func TestTrim(t *testing.T) {
for _, tc := range trimTests {
actual := tc.f(tc.in, tc.cutset)
var name string
switch tc.f {
case Trim:
name = "Trim"
case TrimLeft:
name = "TrimLeft"
case TrimRight:
name = "TrimRight"
name := tc.f
var f func(string, string) string
switch name {
case "Trim":
f = Trim
case "TrimLeft":
f = TrimLeft
case "TrimRight":
f = TrimRight
default:
t.Error("Undefined trim function")
t.Error("Undefined trim function %s", name)
}
actual := f(tc.in, tc.cutset)
if actual != tc.out {
t.Errorf("%s(%q, %q) = %q; want %q", name, tc.in, tc.cutset, actual, tc.out)
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment