Commit afde71cf authored by Olivier Saingre's avatar Olivier Saingre Committed by Brad Fitzpatrick

encoding/xml: make sure Encoder.Encode reports Write errors.

Fixes #4112.

R=remyoudompheng, daniel.morsing, dave, rsc
CC=golang-dev
https://golang.org/cl/7085053
parent 8eb80914
...@@ -193,7 +193,9 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo) error { ...@@ -193,7 +193,9 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo) error {
if xmlns != "" { if xmlns != "" {
p.WriteString(` xmlns="`) p.WriteString(` xmlns="`)
// TODO: EscapeString, to avoid the allocation. // TODO: EscapeString, to avoid the allocation.
Escape(p, []byte(xmlns)) if err := EscapeText(p, []byte(xmlns)); err != nil {
return err
}
p.WriteByte('"') p.WriteByte('"')
} }
...@@ -252,19 +254,22 @@ func (p *printer) marshalSimple(typ reflect.Type, val reflect.Value) error { ...@@ -252,19 +254,22 @@ func (p *printer) marshalSimple(typ reflect.Type, val reflect.Value) error {
p.WriteString(strconv.FormatFloat(val.Float(), 'g', -1, val.Type().Bits())) p.WriteString(strconv.FormatFloat(val.Float(), 'g', -1, val.Type().Bits()))
case reflect.String: case reflect.String:
// TODO: Add EscapeString. // TODO: Add EscapeString.
Escape(p, []byte(val.String())) EscapeText(p, []byte(val.String()))
case reflect.Bool: case reflect.Bool:
p.WriteString(strconv.FormatBool(val.Bool())) p.WriteString(strconv.FormatBool(val.Bool()))
case reflect.Array: case reflect.Array:
// will be [...]byte // will be [...]byte
bytes := make([]byte, val.Len()) var bytes []byte
for i := range bytes { if val.CanAddr() {
bytes[i] = val.Index(i).Interface().(byte) bytes = val.Slice(0, val.Len()).Bytes()
} else {
bytes = make([]byte, val.Len())
reflect.Copy(reflect.ValueOf(bytes), val)
} }
Escape(p, bytes) EscapeText(p, bytes)
case reflect.Slice: case reflect.Slice:
// will be []byte // will be []byte
Escape(p, val.Bytes()) EscapeText(p, val.Bytes())
default: default:
return &UnsupportedTypeError{typ} return &UnsupportedTypeError{typ}
} }
...@@ -298,10 +303,14 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error { ...@@ -298,10 +303,14 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
case reflect.Bool: case reflect.Bool:
Escape(p, strconv.AppendBool(scratch[:0], vf.Bool())) Escape(p, strconv.AppendBool(scratch[:0], vf.Bool()))
case reflect.String: case reflect.String:
Escape(p, []byte(vf.String())) if err := EscapeText(p, []byte(vf.String())); err != nil {
return err
}
case reflect.Slice: case reflect.Slice:
if elem, ok := vf.Interface().([]byte); ok { if elem, ok := vf.Interface().([]byte); ok {
Escape(p, elem) if err := EscapeText(p, elem); err != nil {
return err
}
} }
case reflect.Struct: case reflect.Struct:
if vf.Type() == timeType { if vf.Type() == timeType {
......
...@@ -965,6 +965,16 @@ func TestMarshalWriteErrors(t *testing.T) { ...@@ -965,6 +965,16 @@ func TestMarshalWriteErrors(t *testing.T) {
} }
} }
func TestMarshalWriteIOErrors(t *testing.T) {
enc := NewEncoder(errWriter{})
expectErr := "unwritable"
err := enc.Encode(&Passenger{})
if err == nil || err.Error() != expectErr {
t.Errorf("EscapeTest = [error] %v, want %v", err, expectErr)
}
}
func BenchmarkMarshal(b *testing.B) { func BenchmarkMarshal(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
Marshal(atomValue) Marshal(atomValue)
......
...@@ -1720,9 +1720,9 @@ var ( ...@@ -1720,9 +1720,9 @@ var (
esc_cr = []byte("&#xD;") esc_cr = []byte("&#xD;")
) )
// Escape writes to w the properly escaped XML equivalent // EscapeText writes to w the properly escaped XML equivalent
// of the plain text data s. // of the plain text data s.
func Escape(w io.Writer, s []byte) { func EscapeText(w io.Writer, s []byte) error {
var esc []byte var esc []byte
last := 0 last := 0
for i, c := range s { for i, c := range s {
...@@ -1746,11 +1746,25 @@ func Escape(w io.Writer, s []byte) { ...@@ -1746,11 +1746,25 @@ func Escape(w io.Writer, s []byte) {
default: default:
continue continue
} }
w.Write(s[last:i]) if _, err := w.Write(s[last:i]); err != nil {
w.Write(esc) return err
}
if _, err := w.Write(esc); err != nil {
return err
}
last = i + 1 last = i + 1
} }
w.Write(s[last:]) if _, err := w.Write(s[last:]); err != nil {
return err
}
return nil
}
// Escape is like EscapeText but omits the error return value.
// It is provided for backwards compatibility with Go 1.0.
// Code targeting Go 1.1 or later should use EscapeText.
func Escape(w io.Writer, s []byte) {
EscapeText(w, s)
} }
// procInstEncoding parses the `encoding="..."` or `encoding='...'` // procInstEncoding parses the `encoding="..."` or `encoding='...'`
......
...@@ -689,3 +689,17 @@ func TestDirectivesWithComments(t *testing.T) { ...@@ -689,3 +689,17 @@ func TestDirectivesWithComments(t *testing.T) {
} }
} }
} }
// Writer whose Write method always returns an error.
type errWriter struct{}
func (errWriter) Write(p []byte) (n int, err error) { return 0, fmt.Errorf("unwritable") }
func TestEscapeTextIOErrors(t *testing.T) {
expectErr := "unwritable"
err := EscapeText(errWriter{}, []byte{'A'})
if err == nil || err.Error() != expectErr {
t.Errorf("EscapeTest = [error] %v, want %v", err, expectErr)
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment