Commit 0f3ea750 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

encoding/json: don't cache value addressability when building first encoder

newTypeEncoder (called once per type and then cached) was
looking at the first value seen of that type's addressability
and caching the encoder decision.  If the first value seen was
addressable and a future one wasn't, it would panic.

Instead, introduce a new wrapper encoder type that checks
CanAddr at runtime.

Fixes #6458

R=golang-dev, rsc
CC=golang-dev
https://golang.org/cl/13839045
parent e8bbbe08
......@@ -305,7 +305,7 @@ func (e *encodeState) reflectValue(v reflect.Value) {
valueEncoder(v)(e, v, false)
}
type encoderFunc func(e *encodeState, v reflect.Value, _ bool)
type encoderFunc func(e *encodeState, v reflect.Value, quoted bool)
var encoderCache struct {
sync.RWMutex
......@@ -316,11 +316,10 @@ func valueEncoder(v reflect.Value) encoderFunc {
if !v.IsValid() {
return invalidValueEncoder
}
t := v.Type()
return typeEncoder(t, v)
return typeEncoder(v.Type())
}
func typeEncoder(t reflect.Type, vx reflect.Value) encoderFunc {
func typeEncoder(t reflect.Type) encoderFunc {
encoderCache.RLock()
f := encoderCache.m[t]
encoderCache.RUnlock()
......@@ -346,7 +345,7 @@ func typeEncoder(t reflect.Type, vx reflect.Value) encoderFunc {
// Compute fields without lock.
// Might duplicate effort but won't hold other computations back.
f = newTypeEncoder(t, vx)
f = newTypeEncoder(t, true)
wg.Done()
encoderCache.Lock()
encoderCache.m[t] = f
......@@ -354,38 +353,33 @@ func typeEncoder(t reflect.Type, vx reflect.Value) encoderFunc {
return f
}
// newTypeEncoder constructs an encoderFunc for a type.
// The provided vx is an example value of type t. It's the first seen
// value of that type and should not be used to encode. It may be
// zero.
func newTypeEncoder(t reflect.Type, vx reflect.Value) encoderFunc {
if !vx.IsValid() {
vx = reflect.New(t).Elem()
}
var (
marshalerType = reflect.TypeOf(new(Marshaler)).Elem()
textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem()
)
_, ok := vx.Interface().(Marshaler)
if ok {
// newTypeEncoder constructs an encoderFunc for a type.
// The returned encoder only checks CanAddr when allowAddr is true.
func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc {
if t.Implements(marshalerType) {
return marshalerEncoder
}
if vx.Kind() != reflect.Ptr && vx.CanAddr() {
_, ok = vx.Addr().Interface().(Marshaler)
if ok {
return addrMarshalerEncoder
if t.Kind() != reflect.Ptr && allowAddr {
if reflect.PtrTo(t).Implements(marshalerType) {
return newCondAddrEncoder(addrMarshalerEncoder, newTypeEncoder(t, false))
}
}
_, ok = vx.Interface().(encoding.TextMarshaler)
if ok {
if t.Implements(textMarshalerType) {
return textMarshalerEncoder
}
if vx.Kind() != reflect.Ptr && vx.CanAddr() {
_, ok = vx.Addr().Interface().(encoding.TextMarshaler)
if ok {
return addrTextMarshalerEncoder
if t.Kind() != reflect.Ptr && allowAddr {
if reflect.PtrTo(t).Implements(textMarshalerType) {
return newCondAddrEncoder(addrTextMarshalerEncoder, newTypeEncoder(t, false))
}
}
switch vx.Kind() {
switch t.Kind() {
case reflect.Bool:
return boolEncoder
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
......@@ -401,15 +395,15 @@ func newTypeEncoder(t reflect.Type, vx reflect.Value) encoderFunc {
case reflect.Interface:
return interfaceEncoder
case reflect.Struct:
return newStructEncoder(t, vx)
return newStructEncoder(t)
case reflect.Map:
return newMapEncoder(t, vx)
return newMapEncoder(t)
case reflect.Slice:
return newSliceEncoder(t, vx)
return newSliceEncoder(t)
case reflect.Array:
return newArrayEncoder(t, vx)
return newArrayEncoder(t)
case reflect.Ptr:
return newPtrEncoder(t, vx)
return newPtrEncoder(t)
default:
return unsupportedTypeEncoder
}
......@@ -593,27 +587,19 @@ func (se *structEncoder) encode(e *encodeState, v reflect.Value, quoted bool) {
}
e.string(f.name)
e.WriteByte(':')
if tenc := se.fieldEncs[i]; tenc != nil {
tenc(e, fv, f.quoted)
} else {
// Slower path.
e.reflectValue(fv)
}
se.fieldEncs[i](e, fv, f.quoted)
}
e.WriteByte('}')
}
func newStructEncoder(t reflect.Type, vx reflect.Value) encoderFunc {
func newStructEncoder(t reflect.Type) encoderFunc {
fields := cachedTypeFields(t)
se := &structEncoder{
fields: fields,
fieldEncs: make([]encoderFunc, len(fields)),
}
for i, f := range fields {
vxf := fieldByIndex(vx, f.index)
if vxf.IsValid() {
se.fieldEncs[i] = typeEncoder(vxf.Type(), vxf)
}
se.fieldEncs[i] = typeEncoder(typeByIndex(t, f.index))
}
return se.encode
}
......@@ -641,11 +627,11 @@ func (me *mapEncoder) encode(e *encodeState, v reflect.Value, _ bool) {
e.WriteByte('}')
}
func newMapEncoder(t reflect.Type, vx reflect.Value) encoderFunc {
func newMapEncoder(t reflect.Type) encoderFunc {
if t.Key().Kind() != reflect.String {
return unsupportedTypeEncoder
}
me := &mapEncoder{typeEncoder(vx.Type().Elem(), reflect.Value{})}
me := &mapEncoder{typeEncoder(t.Elem())}
return me.encode
}
......@@ -684,12 +670,12 @@ func (se *sliceEncoder) encode(e *encodeState, v reflect.Value, _ bool) {
se.arrayEnc(e, v, false)
}
func newSliceEncoder(t reflect.Type, vx reflect.Value) encoderFunc {
func newSliceEncoder(t reflect.Type) encoderFunc {
// Byte slices get special treatment; arrays don't.
if vx.Type().Elem().Kind() == reflect.Uint8 {
if t.Elem().Kind() == reflect.Uint8 {
return encodeByteSlice
}
enc := &sliceEncoder{newArrayEncoder(t, vx)}
enc := &sliceEncoder{newArrayEncoder(t)}
return enc.encode
}
......@@ -709,8 +695,8 @@ func (ae *arrayEncoder) encode(e *encodeState, v reflect.Value, _ bool) {
e.WriteByte(']')
}
func newArrayEncoder(t reflect.Type, vx reflect.Value) encoderFunc {
enc := &arrayEncoder{typeEncoder(t.Elem(), reflect.Value{})}
func newArrayEncoder(t reflect.Type) encoderFunc {
enc := &arrayEncoder{typeEncoder(t.Elem())}
return enc.encode
}
......@@ -726,8 +712,27 @@ func (pe *ptrEncoder) encode(e *encodeState, v reflect.Value, _ bool) {
pe.elemEnc(e, v.Elem(), false)
}
func newPtrEncoder(t reflect.Type, vx reflect.Value) encoderFunc {
enc := &ptrEncoder{typeEncoder(t.Elem(), reflect.Value{})}
func newPtrEncoder(t reflect.Type) encoderFunc {
enc := &ptrEncoder{typeEncoder(t.Elem())}
return enc.encode
}
type condAddrEncoder struct {
canAddrEnc, elseEnc encoderFunc
}
func (ce *condAddrEncoder) encode(e *encodeState, v reflect.Value, quoted bool) {
if v.CanAddr() {
ce.canAddrEnc(e, v, quoted)
} else {
ce.elseEnc(e, v, quoted)
}
}
// newCondAddrEncoder returns an encoder that checks whether its value
// CanAddr and delegates to canAddrEnc if so, else to elseEnc.
func newCondAddrEncoder(canAddrEnc, elseEnc encoderFunc) encoderFunc {
enc := &condAddrEncoder{canAddrEnc: canAddrEnc, elseEnc: elseEnc}
return enc.encode
}
......@@ -763,6 +768,16 @@ func fieldByIndex(v reflect.Value, index []int) reflect.Value {
return v
}
func typeByIndex(t reflect.Type, index []int) reflect.Type {
for _, i := range index {
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
t = t.Field(i).Type
}
return t
}
// stringValues is a slice of reflect.Value holding *reflect.StringValue.
// It implements the methods to sort by string.
type stringValues []reflect.Value
......
......@@ -401,3 +401,27 @@ func TestStringBytes(t *testing.T) {
t.Errorf("encodings differ at %#q vs %#q", enc, encBytes)
}
}
func TestIssue6458(t *testing.T) {
type Foo struct {
M RawMessage
}
x := Foo{RawMessage(`"foo"`)}
b, err := Marshal(&x)
if err != nil {
t.Fatal(err)
}
if want := `{"M":"foo"}`; string(b) != want {
t.Errorf("Marshal(&x) = %#q; want %#q", b, want)
}
b, err = Marshal(x)
if err != nil {
t.Fatal(err)
}
if want := `{"M":"ImZvbyI="}`; string(b) != want {
t.Errorf("Marshal(x) = %#q; want %#q", b, want)
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment