Commit 5a966cf2 authored by Joe Tsai's avatar Joe Tsai Committed by Brad Fitzpatrick

compress/zlib: make errors persistent

Ensure that all errors (including io.EOF) are persistent across method
calls on zlib.Reader. Furthermore, ensure that these persistent errors
are properly cleared when Reset is called.

Fixes #14675

Change-Id: I15a20c7e25dc38219e7e0ff255d1ba775a86bb47
Reviewed-on: https://go-review.googlesource.com/20292Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
parent 2a68c6c2
...@@ -84,19 +84,17 @@ func NewReaderDict(r io.Reader, dict []byte) (io.ReadCloser, error) { ...@@ -84,19 +84,17 @@ func NewReaderDict(r io.Reader, dict []byte) (io.ReadCloser, error) {
return z, nil return z, nil
} }
func (z *reader) Read(p []byte) (n int, err error) { func (z *reader) Read(p []byte) (int, error) {
if z.err != nil { if z.err != nil {
return 0, z.err return 0, z.err
} }
if len(p) == 0 {
return 0, nil
}
n, err = z.decompressor.Read(p) var n int
n, z.err = z.decompressor.Read(p)
z.digest.Write(p[0:n]) z.digest.Write(p[0:n])
if n != 0 || err != io.EOF { if z.err != io.EOF {
z.err = err // In the normal case we return here.
return return n, z.err
} }
// Finished file; check checksum. // Finished file; check checksum.
...@@ -105,20 +103,20 @@ func (z *reader) Read(p []byte) (n int, err error) { ...@@ -105,20 +103,20 @@ func (z *reader) Read(p []byte) (n int, err error) {
err = io.ErrUnexpectedEOF err = io.ErrUnexpectedEOF
} }
z.err = err z.err = err
return 0, err return n, z.err
} }
// ZLIB (RFC 1950) is big-endian, unlike GZIP (RFC 1952). // ZLIB (RFC 1950) is big-endian, unlike GZIP (RFC 1952).
checksum := uint32(z.scratch[0])<<24 | uint32(z.scratch[1])<<16 | uint32(z.scratch[2])<<8 | uint32(z.scratch[3]) checksum := uint32(z.scratch[0])<<24 | uint32(z.scratch[1])<<16 | uint32(z.scratch[2])<<8 | uint32(z.scratch[3])
if checksum != z.digest.Sum32() { if checksum != z.digest.Sum32() {
z.err = ErrChecksum z.err = ErrChecksum
return 0, z.err return n, z.err
} }
return return n, io.EOF
} }
// Calling Close does not close the wrapped io.Reader originally passed to NewReader. // Calling Close does not close the wrapped io.Reader originally passed to NewReader.
func (z *reader) Close() error { func (z *reader) Close() error {
if z.err != nil { if z.err != nil && z.err != io.EOF {
return z.err return z.err
} }
z.err = z.decompressor.Close() z.err = z.decompressor.Close()
...@@ -126,36 +124,42 @@ func (z *reader) Close() error { ...@@ -126,36 +124,42 @@ func (z *reader) Close() error {
} }
func (z *reader) Reset(r io.Reader, dict []byte) error { func (z *reader) Reset(r io.Reader, dict []byte) error {
*z = reader{decompressor: z.decompressor}
if fr, ok := r.(flate.Reader); ok { if fr, ok := r.(flate.Reader); ok {
z.r = fr z.r = fr
} else { } else {
z.r = bufio.NewReader(r) z.r = bufio.NewReader(r)
} }
_, err := io.ReadFull(z.r, z.scratch[0:2])
if err != nil { // Read the header (RFC 1950 section 2.2.).
if err == io.EOF { _, z.err = io.ReadFull(z.r, z.scratch[0:2])
err = io.ErrUnexpectedEOF if z.err != nil {
if z.err == io.EOF {
z.err = io.ErrUnexpectedEOF
} }
return err return z.err
} }
h := uint(z.scratch[0])<<8 | uint(z.scratch[1]) h := uint(z.scratch[0])<<8 | uint(z.scratch[1])
if (z.scratch[0]&0x0f != zlibDeflate) || (h%31 != 0) { if (z.scratch[0]&0x0f != zlibDeflate) || (h%31 != 0) {
return ErrHeader z.err = ErrHeader
return z.err
} }
haveDict := z.scratch[1]&0x20 != 0 haveDict := z.scratch[1]&0x20 != 0
if haveDict { if haveDict {
_, err = io.ReadFull(z.r, z.scratch[0:4]) _, z.err = io.ReadFull(z.r, z.scratch[0:4])
if err != nil { if z.err != nil {
if err == io.EOF { if z.err == io.EOF {
err = io.ErrUnexpectedEOF z.err = io.ErrUnexpectedEOF
} }
return err return z.err
} }
checksum := uint32(z.scratch[0])<<24 | uint32(z.scratch[1])<<16 | uint32(z.scratch[2])<<8 | uint32(z.scratch[3]) checksum := uint32(z.scratch[0])<<24 | uint32(z.scratch[1])<<16 | uint32(z.scratch[2])<<8 | uint32(z.scratch[3])
if checksum != adler32.Checksum(dict) { if checksum != adler32.Checksum(dict) {
return ErrDictionary z.err = ErrDictionary
return z.err
} }
} }
if z.decompressor == nil { if z.decompressor == nil {
if haveDict { if haveDict {
z.decompressor = flate.NewReaderDict(z.r, dict) z.decompressor = flate.NewReaderDict(z.r, dict)
......
...@@ -127,16 +127,18 @@ func TestDecompressor(t *testing.T) { ...@@ -127,16 +127,18 @@ func TestDecompressor(t *testing.T) {
b := new(bytes.Buffer) b := new(bytes.Buffer)
for _, tt := range zlibTests { for _, tt := range zlibTests {
in := bytes.NewReader(tt.compressed) in := bytes.NewReader(tt.compressed)
zlib, err := NewReaderDict(in, tt.dict) zr, err := NewReaderDict(in, tt.dict)
if err != nil { if err != nil {
if err != tt.err { if err != tt.err {
t.Errorf("%s: NewReader: %s", tt.desc, err) t.Errorf("%s: NewReader: %s", tt.desc, err)
} }
continue continue
} }
defer zlib.Close() defer zr.Close()
// Read and verify correctness of data.
b.Reset() b.Reset()
n, err := io.Copy(b, zlib) n, err := io.Copy(b, zr)
if err != nil { if err != nil {
if err != tt.err { if err != tt.err {
t.Errorf("%s: io.Copy: %v want %v", tt.desc, err, tt.err) t.Errorf("%s: io.Copy: %v want %v", tt.desc, err, tt.err)
...@@ -147,5 +149,13 @@ func TestDecompressor(t *testing.T) { ...@@ -147,5 +149,13 @@ func TestDecompressor(t *testing.T) {
if s != tt.raw { if s != tt.raw {
t.Errorf("%s: got %d-byte %q want %d-byte %q", tt.desc, n, s, len(tt.raw), tt.raw) t.Errorf("%s: got %d-byte %q want %d-byte %q", tt.desc, n, s, len(tt.raw), tt.raw)
} }
// Check for sticky errors.
if n, err := zr.Read([]byte{0}); n != 0 || err != io.EOF {
t.Errorf("%s: Read() = (%d, %v), want (0, io.EOF)", tt.desc, n, err)
}
if err := zr.Close(); err != nil {
t.Errorf("%s: Close() = %v, want nil", tt.desc, err)
}
} }
} }
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment