Commit 79d06d72 authored by Nigel Tao's avatar Nigel Tao

encoding/base32: don't panic when decoding "AAAA==".

Edit encoding/base64's internals and tests to match encoding/base32.

Properly handling line breaks in padding is left for another CL.

R=dsymonds
CC=golang-dev
https://golang.org/cl/7693044
parent 1b3c969a
...@@ -236,7 +236,6 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { ...@@ -236,7 +236,6 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
var dbuf [8]byte var dbuf [8]byte
dlen := 8 dlen := 8
// do the top bytes contain any data?
for j := 0; j < 8; { for j := 0; j < 8; {
if len(src) == 0 { if len(src) == 0 {
return n, false, CorruptInputError(len(osrc) - len(src) - j) return n, false, CorruptInputError(len(osrc) - len(src) - j)
...@@ -248,15 +247,26 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { ...@@ -248,15 +247,26 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
continue continue
} }
if in == '=' && j >= 2 && len(src) < 8 { if in == '=' && j >= 2 && len(src) < 8 {
// We've reached the end and there's // We've reached the end and there's padding
// padding, the rest should be padded if len(src)+j < 8-1 {
for k := 0; k < 8-j-1; k++ { // not enough padding
return n, false, CorruptInputError(len(osrc))
}
for k := 0; k < 8-1-j; k++ {
if len(src) > k && src[k] != '=' { if len(src) > k && src[k] != '=' {
// incorrect padding
return n, false, CorruptInputError(len(osrc) - len(src) + k - 1) return n, false, CorruptInputError(len(osrc) - len(src) + k - 1)
} }
} }
dlen = j dlen, end = j, true
end = true // 7, 5 and 2 are not valid padding lengths, and so 1, 3 and 6 are not
// valid dlen values. See RFC 4648 Section 6 "Base 32 Encoding" listing
// the five valid padding lengths, and Section 9 "Illustrations and
// Examples" for an illustration for how the the 1st, 3rd and 6th base32
// src bytes do not yield enough information to decode a dst byte.
if dlen == 1 || dlen == 3 || dlen == 6 {
return n, false, CorruptInputError(len(osrc) - len(src) - 1)
}
break break
} }
dbuf[j] = enc.decodeMap[in] dbuf[j] = enc.decodeMap[in]
...@@ -269,16 +279,16 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { ...@@ -269,16 +279,16 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
// Pack 8x 5-bit source blocks into 5 byte destination // Pack 8x 5-bit source blocks into 5 byte destination
// quantum // quantum
switch dlen { switch dlen {
case 7, 8: case 8:
dst[4] = dbuf[6]<<5 | dbuf[7] dst[4] = dbuf[6]<<5 | dbuf[7]
fallthrough fallthrough
case 6, 5: case 7:
dst[3] = dbuf[4]<<7 | dbuf[5]<<2 | dbuf[6]>>3 dst[3] = dbuf[4]<<7 | dbuf[5]<<2 | dbuf[6]>>3
fallthrough fallthrough
case 4: case 5:
dst[2] = dbuf[3]<<4 | dbuf[4]>>1 dst[2] = dbuf[3]<<4 | dbuf[4]>>1
fallthrough fallthrough
case 3: case 4:
dst[1] = dbuf[1]<<6 | dbuf[2]<<1 | dbuf[3]>>4 dst[1] = dbuf[1]<<6 | dbuf[2]<<1 | dbuf[3]>>4
fallthrough fallthrough
case 2: case 2:
...@@ -288,11 +298,11 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { ...@@ -288,11 +298,11 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
switch dlen { switch dlen {
case 2: case 2:
n += 1 n += 1
case 3, 4: case 4:
n += 2 n += 2
case 5: case 5:
n += 3 n += 3
case 6, 7: case 7:
n += 4 n += 4
case 8: case 8:
n += 5 n += 5
......
...@@ -137,27 +137,48 @@ func TestDecoderBuffering(t *testing.T) { ...@@ -137,27 +137,48 @@ func TestDecoderBuffering(t *testing.T) {
} }
func TestDecodeCorrupt(t *testing.T) { func TestDecodeCorrupt(t *testing.T) {
type corrupt struct { testCases := []struct {
e string input string
p int offset int // -1 means no corruption.
} }{
examples := []corrupt{ {"", -1},
{"!!!!", 0}, {"!!!!", 0},
{"x===", 0}, {"x===", 0},
{"AA=A====", 2}, {"AA=A====", 2},
{"AAA=AAAA", 3}, {"AAA=AAAA", 3},
{"MMMMMMMMM", 8}, {"MMMMMMMMM", 8},
{"MMMMMM", 0}, {"MMMMMM", 0},
} {"A=", 1},
{"AA=", 3},
for _, e := range examples { {"AA==", 4},
dbuf := make([]byte, StdEncoding.DecodedLen(len(e.e))) {"AA===", 5},
_, err := StdEncoding.Decode(dbuf, []byte(e.e)) {"AAAA=", 5},
{"AAAA==", 6},
{"AAAAA=", 6},
{"AAAAA==", 7},
{"A=======", 1},
{"AA======", -1},
{"AAA=====", 3},
{"AAAA====", -1},
{"AAAAA===", -1},
{"AAAAAA==", 6},
{"AAAAAAA=", -1},
{"AAAAAAAA", -1},
}
for _, tc := range testCases {
dbuf := make([]byte, StdEncoding.DecodedLen(len(tc.input)))
_, err := StdEncoding.Decode(dbuf, []byte(tc.input))
if tc.offset == -1 {
if err != nil {
t.Error("Decoder wrongly detected coruption in", tc.input)
}
continue
}
switch err := err.(type) { switch err := err.(type) {
case CorruptInputError: case CorruptInputError:
testEqual(t, "Corruption in %q at offset %v, want %v", e.e, int(err), e.p) testEqual(t, "Corruption in %q at offset %v, want %v", tc.input, int(err), tc.offset)
default: default:
t.Error("Decoder failed to detect corruption in", e) t.Error("Decoder failed to detect corruption in", tc)
} }
} }
} }
......
...@@ -227,9 +227,8 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { ...@@ -227,9 +227,8 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
continue continue
} }
if in == '=' && j >= 2 && len(src) < 4 { if in == '=' && j >= 2 && len(src) < 4 {
// We've reached the end and there's // We've reached the end and there's padding
// padding if len(src)+j < 4-1 {
if len(src) == 0 && j == 2 {
// not enough padding // not enough padding
return n, false, CorruptInputError(len(osrc)) return n, false, CorruptInputError(len(osrc))
} }
...@@ -237,8 +236,7 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { ...@@ -237,8 +236,7 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
// incorrect padding // incorrect padding
return n, false, CorruptInputError(len(osrc) - len(src) - 1) return n, false, CorruptInputError(len(osrc) - len(src) - 1)
} }
dlen = j dlen, end = j, true
end = true
break break
} }
dbuf[j] = enc.decodeMap[in] dbuf[j] = enc.decodeMap[in]
......
...@@ -142,11 +142,11 @@ func TestDecoderBuffering(t *testing.T) { ...@@ -142,11 +142,11 @@ func TestDecoderBuffering(t *testing.T) {
} }
func TestDecodeCorrupt(t *testing.T) { func TestDecodeCorrupt(t *testing.T) {
type corrupt struct { testCases := []struct {
e string input string
p int offset int // -1 means no corruption.
} }{
examples := []corrupt{ {"", -1},
{"!!!!", 0}, {"!!!!", 0},
{"x===", 1}, {"x===", 1},
{"AA=A", 2}, {"AA=A", 2},
...@@ -154,18 +154,27 @@ func TestDecodeCorrupt(t *testing.T) { ...@@ -154,18 +154,27 @@ func TestDecodeCorrupt(t *testing.T) {
{"AAAAA", 4}, {"AAAAA", 4},
{"AAAAAA", 4}, {"AAAAAA", 4},
{"A=", 1}, {"A=", 1},
{"A==", 1},
{"AA=", 3}, {"AA=", 3},
{"AA==", -1},
{"AAA=", -1},
{"AAAA", -1},
{"AAAAAA=", 7}, {"AAAAAA=", 7},
} }
for _, tc := range testCases {
for _, e := range examples { dbuf := make([]byte, StdEncoding.DecodedLen(len(tc.input)))
dbuf := make([]byte, StdEncoding.DecodedLen(len(e.e))) _, err := StdEncoding.Decode(dbuf, []byte(tc.input))
_, err := StdEncoding.Decode(dbuf, []byte(e.e)) if tc.offset == -1 {
if err != nil {
t.Error("Decoder wrongly detected coruption in", tc.input)
}
continue
}
switch err := err.(type) { switch err := err.(type) {
case CorruptInputError: case CorruptInputError:
testEqual(t, "Corruption in %q at offset %v, want %v", e.e, int(err), e.p) testEqual(t, "Corruption in %q at offset %v, want %v", tc.input, int(err), tc.offset)
default: default:
t.Error("Decoder failed to detect corruption in", e) t.Error("Decoder failed to detect corruption in", tc)
} }
} }
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment