Commit aea22438 authored by Robert Griesemer's avatar Robert Griesemer

math/big: more tests, documentation for Flot gob marshalling

Follow-up to https://golang.org/cl/21755.
This turned out to be a bit more than just a few nits
as originally expected in that CL.

1) The actual mantissa may be shorter than required for the
   given precision (because of trailing 0's): no need to
   allocate space for it (and transmit 0's). This can save
   a lot of space when the precision is high: E.g., for
   prec == 1000, 16 words or 128 bytes are required at the
   most, but if the actual number is short, it may be much
   less (for the test cases present, it's significantly less).

2) The actual mantissa may be longer than the number of
   words required for the given precision: make sure to
   not overflow when encoding in bytes.

3) Add more documentation.

4) Add more tests.

Change-Id: I9f40c408cfdd9183a8e81076d2f7d6c75e7a00e9
Reviewed-on: https://go-review.googlesource.com/22324Reviewed-by: default avatarAlan Donovan <adonovan@google.com>
parent 60fd32a4
...@@ -15,13 +15,29 @@ import ( ...@@ -15,13 +15,29 @@ import (
const floatGobVersion byte = 1 const floatGobVersion byte = 1
// GobEncode implements the gob.GobEncoder interface. // GobEncode implements the gob.GobEncoder interface.
// The Float value and all its attributes (precision,
// rounding mode, accuracy) are marshalled.
func (x *Float) GobEncode() ([]byte, error) { func (x *Float) GobEncode() ([]byte, error) {
if x == nil { if x == nil {
return nil, nil return nil, nil
} }
// determine max. space (bytes) required for encoding
sz := 1 + 1 + 4 // version + mode|acc|form|neg (3+2+2+1bit) + prec sz := 1 + 1 + 4 // version + mode|acc|form|neg (3+2+2+1bit) + prec
n := 0 // number of mantissa words
if x.form == finite { if x.form == finite {
sz += 4 + int((x.prec+(_W-1))/_W)*_S // exp + mant // add space for mantissa and exponent
n = int((x.prec + (_W - 1)) / _W) // required mantissa length in words for given precision
// actual mantissa slice could be shorter (trailing 0's) or longer (unused bits):
// - if shorter, only encode the words present
// - if longer, cut off unused words when encoding in bytes
// (in practice, this should never happen since rounding
// takes care of it, but be safe and do it always)
if len(x.mant) < n {
n = len(x.mant)
}
// len(x.mant) >= n
sz += 4 + n*_S // exp + mant
} }
buf := make([]byte, sz) buf := make([]byte, sz)
...@@ -32,14 +48,19 @@ func (x *Float) GobEncode() ([]byte, error) { ...@@ -32,14 +48,19 @@ func (x *Float) GobEncode() ([]byte, error) {
} }
buf[1] = b buf[1] = b
binary.BigEndian.PutUint32(buf[2:], x.prec) binary.BigEndian.PutUint32(buf[2:], x.prec)
if x.form == finite { if x.form == finite {
binary.BigEndian.PutUint32(buf[6:], uint32(x.exp)) binary.BigEndian.PutUint32(buf[6:], uint32(x.exp))
x.mant.bytes(buf[10:]) x.mant[len(x.mant)-n:].bytes(buf[10:]) // cut off unused trailing words
} }
return buf, nil return buf, nil
} }
// GobDecode implements the gob.GobDecoder interface. // GobDecode implements the gob.GobDecoder interface.
// The result is rounded per the precision and rounding mode of
// z unless z's precision is 0, in which case z is set exactly
// to the decoded value.
func (z *Float) GobDecode(buf []byte) error { func (z *Float) GobDecode(buf []byte) error {
if len(buf) == 0 { if len(buf) == 0 {
// Other side sent a nil or default value. // Other side sent a nil or default value.
...@@ -51,13 +72,14 @@ func (z *Float) GobDecode(buf []byte) error { ...@@ -51,13 +72,14 @@ func (z *Float) GobDecode(buf []byte) error {
return fmt.Errorf("Float.GobDecode: encoding version %d not supported", buf[0]) return fmt.Errorf("Float.GobDecode: encoding version %d not supported", buf[0])
} }
oldPrec := z.prec
oldMode := z.mode
b := buf[1] b := buf[1]
z.mode = RoundingMode((b >> 5) & 7) z.mode = RoundingMode((b >> 5) & 7)
z.acc = Accuracy((b>>3)&3) - 1 z.acc = Accuracy((b>>3)&3) - 1
z.form = form((b >> 1) & 3) z.form = form((b >> 1) & 3)
z.neg = b&1 != 0 z.neg = b&1 != 0
oldPrec := uint(z.prec)
z.prec = binary.BigEndian.Uint32(buf[2:]) z.prec = binary.BigEndian.Uint32(buf[2:])
if z.form == finite { if z.form == finite {
...@@ -66,8 +88,10 @@ func (z *Float) GobDecode(buf []byte) error { ...@@ -66,8 +88,10 @@ func (z *Float) GobDecode(buf []byte) error {
} }
if oldPrec != 0 { if oldPrec != 0 {
z.SetPrec(oldPrec) z.mode = oldMode
z.SetPrec(uint(oldPrec))
} }
return nil return nil
} }
......
...@@ -28,43 +28,60 @@ var floatVals = []string{ ...@@ -28,43 +28,60 @@ var floatVals = []string{
func TestFloatGobEncoding(t *testing.T) { func TestFloatGobEncoding(t *testing.T) {
var medium bytes.Buffer var medium bytes.Buffer
enc := gob.NewEncoder(&medium)
dec := gob.NewDecoder(&medium)
for _, test := range floatVals { for _, test := range floatVals {
for _, sign := range []string{"", "+", "-"} { for _, sign := range []string{"", "+", "-"} {
for _, prec := range []uint{0, 1, 2, 10, 53, 64, 100, 1000} { for _, prec := range []uint{0, 1, 2, 10, 53, 64, 100, 1000} {
medium.Reset() // empty buffer for each test case (in case of failures) for _, mode := range []RoundingMode{ToNearestEven, ToNearestAway, ToZero, AwayFromZero, ToNegativeInf, ToPositiveInf} {
enc := gob.NewEncoder(&medium) medium.Reset() // empty buffer for each test case (in case of failures)
dec := gob.NewDecoder(&medium) x := sign + test
x := sign + test
var tx Float
_, _, err := tx.SetPrec(prec).Parse(x, 0)
if err != nil {
t.Errorf("parsing of %s (prec = %d) failed (invalid test case): %v", x, prec, err)
continue
}
tx.SetMode(ToPositiveInf)
if err := enc.Encode(&tx); err != nil {
t.Errorf("encoding of %v (prec = %d) failed: %v", &tx, prec, err)
continue
}
var rx Float var tx Float
if err := dec.Decode(&rx); err != nil { _, _, err := tx.SetPrec(prec).SetMode(mode).Parse(x, 0)
t.Errorf("decoding of %v (prec = %d) failed: %v", &tx, prec, err) if err != nil {
continue t.Errorf("parsing of %s (%dbits, %v) failed (invalid test case): %v", x, prec, mode, err)
} continue
}
if rx.Cmp(&tx) != 0 { // If tx was set to prec == 0, tx.Parse(x, 0) assumes precision 64. Correct it.
t.Errorf("transmission of %s failed: got %s want %s", x, rx.String(), tx.String()) if prec == 0 {
continue tx.SetPrec(0)
} }
if err := enc.Encode(&tx); err != nil {
t.Errorf("encoding of %v (%dbits, %v) failed: %v", &tx, prec, mode, err)
continue
}
var rx Float
if err := dec.Decode(&rx); err != nil {
t.Errorf("decoding of %v (%dbits, %v) failed: %v", &tx, prec, mode, err)
continue
}
if rx.Cmp(&tx) != 0 {
t.Errorf("transmission of %s failed: got %s want %s", x, rx.String(), tx.String())
continue
}
if rx.Prec() != prec {
t.Errorf("transmission of %s's prec failed: got %d want %d", x, rx.Prec(), prec)
}
if rx.Mode() != mode {
t.Errorf("transmission of %s's mode failed: got %s want %s", x, rx.Mode(), mode)
}
if rx.Mode() != ToPositiveInf { if rx.Acc() != tx.Acc() {
t.Errorf("transmission of %s's mode failed: got %s want %s", x, rx.Mode(), ToPositiveInf) t.Errorf("transmission of %s's accuracy failed: got %s want %s", x, rx.Acc(), tx.Acc())
}
} }
} }
} }
} }
} }
func TestFloatCorruptGob(t *testing.T) { func TestFloatCorruptGob(t *testing.T) {
var buf bytes.Buffer var buf bytes.Buffer
tx := NewFloat(4 / 3).SetPrec(1000).SetMode(ToPositiveInf) tx := NewFloat(4 / 3).SetPrec(1000).SetMode(ToPositiveInf)
...@@ -72,20 +89,22 @@ func TestFloatCorruptGob(t *testing.T) { ...@@ -72,20 +89,22 @@ func TestFloatCorruptGob(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
b := buf.Bytes() b := buf.Bytes()
var rx Float var rx Float
if err := gob.NewDecoder(bytes.NewReader(b)).Decode(&rx); err != nil { if err := gob.NewDecoder(bytes.NewReader(b)).Decode(&rx); err != nil {
t.Fatal(err) t.Fatal(err)
} }
var rx2 Float
if err := gob.NewDecoder(bytes.NewReader(b[:10])).Decode(&rx2); err != io.ErrUnexpectedEOF { if err := gob.NewDecoder(bytes.NewReader(b[:10])).Decode(&rx); err != io.ErrUnexpectedEOF {
t.Errorf("expected io.ErrUnexpectedEOF, got %v", err) t.Errorf("got %v want EOF", err)
} }
b[1] = 0 b[1] = 0
if err := gob.NewDecoder(bytes.NewReader(b)).Decode(&rx); err == nil { if err := gob.NewDecoder(bytes.NewReader(b)).Decode(&rx); err == nil {
t.Fatal("expected a version error, got nil") t.Fatal("got nil want version error")
} }
} }
func TestFloatJSONEncoding(t *testing.T) { func TestFloatJSONEncoding(t *testing.T) {
for _, test := range floatVals { for _, test := range floatVals {
for _, sign := range []string{"", "+", "-"} { for _, sign := range []string{"", "+", "-"} {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment