Commit 95cbcc5c authored by Ariel Mashraki's avatar Ariel Mashraki Committed by Rob Pike

text/template: support all comparable types in eq

Extends the built-in eq function to support all Go
comparable types.

Fixes #33740

Change-Id: I522310e313e251c4dc6a013d33d7c2034fe2ec8e
Reviewed-on: https://go-review.googlesource.com/c/go/+/193837
Run-TryBot: Rob Pike <r@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarRob Pike <r@golang.org>
parent 434e4caf
...@@ -385,14 +385,12 @@ returning in effect ...@@ -385,14 +385,12 @@ returning in effect
(Unlike with || in Go, however, eq is a function call and all the (Unlike with || in Go, however, eq is a function call and all the
arguments will be evaluated.) arguments will be evaluated.)
The comparison functions work on basic types only (or named basic The comparison functions work on any values whose type Go defines as
types, such as "type Celsius float32"). They implement the Go rules comparable. For basic types such as integers, the rules are relaxed:
for comparison of values, except that size and exact type are size and exact type are ignored, so any integer value, signed or unsigned,
ignored, so any integer value, signed or unsigned, may be compared may be compared with any other integer value. (The arithmetic value is compared,
with any other integer value. (The arithmetic value is compared, not the bit pattern, so all negative integers are less than all unsigned integers.)
not the bit pattern, so all negative integers are less than all However, as usual, one may not compare an int with a float32 and so on.
unsigned integers.) However, as usual, one may not compare an int
with a float32 and so on.
Associated templates Associated templates
......
...@@ -1156,11 +1156,22 @@ var cmpTests = []cmpTest{ ...@@ -1156,11 +1156,22 @@ var cmpTests = []cmpTest{
{"ge .Uthree .NegOne", "true", true}, {"ge .Uthree .NegOne", "true", true},
{"eq (index `x` 0) 'x'", "true", true}, // The example that triggered this rule. {"eq (index `x` 0) 'x'", "true", true}, // The example that triggered this rule.
{"eq (index `x` 0) 'y'", "false", true}, {"eq (index `x` 0) 'y'", "false", true},
{"eq .V1 .V2", "true", true},
{"eq .Ptr .Ptr", "true", true},
{"eq .Ptr .NilPtr", "false", true},
{"eq .NilPtr .NilPtr", "true", true},
{"eq .Iface1 .Iface1", "true", true},
{"eq .Iface1 .Iface2", "false", true},
{"eq .Iface2 .Iface2", "true", true},
// Errors // Errors
{"eq `xy` 1", "", false}, // Different types. {"eq `xy` 1", "", false}, // Different types.
{"eq 2 2.0", "", false}, // Different types. {"eq 2 2.0", "", false}, // Different types.
{"lt true true", "", false}, // Unordered types. {"lt true true", "", false}, // Unordered types.
{"lt 1+0i 1+0i", "", false}, // Unordered types. {"lt 1+0i 1+0i", "", false}, // Unordered types.
{"eq .Ptr 1", "", false}, // Incompatible types.
{"eq .Ptr .NegOne", "", false}, // Incompatible types.
{"eq .Map .Map", "", false}, // Uncomparable types.
{"eq .Map .V1", "", false}, // Uncomparable types.
} }
func TestComparison(t *testing.T) { func TestComparison(t *testing.T) {
...@@ -1168,7 +1179,18 @@ func TestComparison(t *testing.T) { ...@@ -1168,7 +1179,18 @@ func TestComparison(t *testing.T) {
var cmpStruct = struct { var cmpStruct = struct {
Uthree, Ufour uint Uthree, Ufour uint
NegOne, Three int NegOne, Three int
}{3, 4, -1, 3} Ptr, NilPtr *int
Map map[int]int
V1, V2 V
Iface1, Iface2 fmt.Stringer
}{
Uthree: 3,
Ufour: 4,
NegOne: -1,
Three: 3,
Ptr: new(int),
Iface1: b,
}
for _, test := range cmpTests { for _, test := range cmpTests {
text := fmt.Sprintf("{{if %s}}true{{else}}false{{end}}", test.expr) text := fmt.Sprintf("{{if %s}}true{{else}}false{{end}}", test.expr)
tmpl, err := New("empty").Parse(text) tmpl, err := New("empty").Parse(text)
......
...@@ -441,19 +441,18 @@ func basicKind(v reflect.Value) (kind, error) { ...@@ -441,19 +441,18 @@ func basicKind(v reflect.Value) (kind, error) {
// eq evaluates the comparison a == b || a == c || ... // eq evaluates the comparison a == b || a == c || ...
func eq(arg1 reflect.Value, arg2 ...reflect.Value) (bool, error) { func eq(arg1 reflect.Value, arg2 ...reflect.Value) (bool, error) {
v1 := indirectInterface(arg1) v1 := indirectInterface(arg1)
k1, err := basicKind(v1) if v1 != zero {
if err != nil { if t1 := v1.Type(); !t1.Comparable() {
return false, err return false, fmt.Errorf("uncomparable type %s: %v", t1, v1)
}
} }
if len(arg2) == 0 { if len(arg2) == 0 {
return false, errNoComparison return false, errNoComparison
} }
k1, _ := basicKind(v1)
for _, arg := range arg2 { for _, arg := range arg2 {
v2 := indirectInterface(arg) v2 := indirectInterface(arg)
k2, err := basicKind(v2) k2, _ := basicKind(v2)
if err != nil {
return false, err
}
truth := false truth := false
if k1 != k2 { if k1 != k2 {
// Special case: Can compare integer values regardless of type's sign. // Special case: Can compare integer values regardless of type's sign.
...@@ -480,7 +479,14 @@ func eq(arg1 reflect.Value, arg2 ...reflect.Value) (bool, error) { ...@@ -480,7 +479,14 @@ func eq(arg1 reflect.Value, arg2 ...reflect.Value) (bool, error) {
case uintKind: case uintKind:
truth = v1.Uint() == v2.Uint() truth = v1.Uint() == v2.Uint()
default: default:
panic("invalid kind") if v2 == zero {
truth = v1 == v2
} else {
if t2 := v2.Type(); !t2.Comparable() {
return false, fmt.Errorf("uncomparable type %s: %v", t2, v2)
}
truth = v1.Interface() == v2.Interface()
}
} }
} }
if truth { if truth {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment