Commit 64c9ee98 authored by Daniel Martí's avatar Daniel Martí Committed by Andrew Bonventre

encoding/json: error when encoding a pointer cycle

Otherwise we'd panic with a stack overflow.

Most programs are in control of the data being encoded and can ensure
there are no cycles, but sometimes it's not that simple. For example,
running a user's html template with script tags can easily result in
crashes if the user can find a pointer cycle.

Adding the checks via a map to every ptrEncoder.encode call slowed down
the benchmarks below by a noticeable 13%. Instead, only start doing the
relatively expensive pointer cycle checks if we're many levels of
pointers deep in an encode state.

A threshold of 1000 is small enough to capture pointer cycles before
they're a problem (the goroutine stack limit is currently 1GB, and I
needed close to a million levels to reach it). Yet it's large enough
that reasonable uses of the json encoder only see a tiny 1% slow-down
due to the added ptrLevel field and check.

	name           old time/op    new time/op    delta
	CodeEncoder-8    2.34ms ± 1%    2.37ms ± 0%  +1.05%  (p=0.000 n=10+10)
	CodeMarshal-8    2.42ms ± 1%    2.44ms ± 0%  +1.10%  (p=0.000 n=10+10)

	name           old speed      new speed      delta
	CodeEncoder-8   829MB/s ± 1%   820MB/s ± 0%  -1.04%  (p=0.000 n=10+10)
	CodeMarshal-8   803MB/s ± 1%   795MB/s ± 0%  -1.09%  (p=0.000 n=10+10)

	name           old alloc/op   new alloc/op   delta
	CodeEncoder-8    43.1kB ± 8%    42.5kB ±10%    ~     (p=0.989 n=10+10)
	CodeMarshal-8    1.99MB ± 0%    1.99MB ± 0%    ~     (p=0.254 n=9+6)

	name           old allocs/op  new allocs/op  delta
	CodeEncoder-8      0.00           0.00         ~     (all equal)
	CodeMarshal-8      1.00 ± 0%      1.00 ± 0%    ~     (all equal)

Finally, add a few tests to ensure that the code handles the edge cases
properly.

Fixes #10769.

Change-Id: I73d48e0cf6ea140127ea031f2dbae6e6a55e58b8
Reviewed-on: https://go-review.googlesource.com/c/go/+/187920
Run-TryBot: Daniel Martí <mvdan@mvdan.cc>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarBjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com>
Reviewed-by: default avatarAndrew Bonventre <andybons@golang.org>
parent f5114675
......@@ -153,7 +153,7 @@ import (
//
// JSON cannot represent cyclic data structures and Marshal does not
// handle them. Passing cyclic structures to Marshal will result in
// an infinite recursion.
// an error.
//
func Marshal(v interface{}) ([]byte, error) {
e := newEncodeState()
......@@ -285,17 +285,31 @@ var hex = "0123456789abcdef"
type encodeState struct {
bytes.Buffer // accumulated output
scratch [64]byte
// Keep track of what pointers we've seen in the current recursive call
// path, to avoid cycles that could lead to a stack overflow. Only do
// the relatively expensive map operations if ptrLevel is larger than
// startDetectingCyclesAfter, so that we skip the work if we're within a
// reasonable amount of nested pointers deep.
ptrLevel uint
ptrSeen map[interface{}]struct{}
}
const startDetectingCyclesAfter = 1000
var encodeStatePool sync.Pool
func newEncodeState() *encodeState {
if v := encodeStatePool.Get(); v != nil {
e := v.(*encodeState)
e.Reset()
if len(e.ptrSeen) > 0 {
panic("ptrEncoder.encode should have emptied ptrSeen via defers")
}
e.ptrLevel = 0
return e
}
return new(encodeState)
return &encodeState{ptrSeen: make(map[interface{}]struct{})}
}
// jsonError is an error wrapper type for internal use only.
......@@ -887,7 +901,18 @@ func (pe ptrEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
e.WriteString("null")
return
}
if e.ptrLevel++; e.ptrLevel > startDetectingCyclesAfter {
// We're a large number of nested ptrEncoder.encode calls deep;
// start checking if we've run into a pointer cycle.
ptr := v.Interface()
if _, ok := e.ptrSeen[ptr]; ok {
e.error(&UnsupportedValueError{v, fmt.Sprintf("encountered a cycle via %s", v.Type())})
}
e.ptrSeen[ptr] = struct{}{}
defer delete(e.ptrSeen, ptr)
}
pe.elemEnc(e, v.Elem(), opts)
e.ptrLevel--
}
func newPtrEncoder(t reflect.Type) encoderFunc {
......
......@@ -138,10 +138,45 @@ func TestEncodeRenamedByteSlice(t *testing.T) {
}
}
type SamePointerNoCycle struct {
Ptr1, Ptr2 *SamePointerNoCycle
}
var samePointerNoCycle = &SamePointerNoCycle{}
type PointerCycle struct {
Ptr *PointerCycle
}
var pointerCycle = &PointerCycle{}
type PointerCycleIndirect struct {
Ptrs []interface{}
}
var pointerCycleIndirect = &PointerCycleIndirect{}
func init() {
ptr := &SamePointerNoCycle{}
samePointerNoCycle.Ptr1 = ptr
samePointerNoCycle.Ptr2 = ptr
pointerCycle.Ptr = pointerCycle
pointerCycleIndirect.Ptrs = []interface{}{pointerCycleIndirect}
}
func TestSamePointerNoCycle(t *testing.T) {
if _, err := Marshal(samePointerNoCycle); err != nil {
t.Fatalf("unexpected error: %v", err)
}
}
var unsupportedValues = []interface{}{
math.NaN(),
math.Inf(-1),
math.Inf(1),
pointerCycle,
pointerCycleIndirect,
}
func TestUnsupportedValues(t *testing.T) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment