Commit 8206bafb authored by Russ Cox's avatar Russ Cox

asn1: make interface consistent with json

Replace Marshal with MarshalToMemory
(no one was using old Marshal anyway).

Swap arguments to Unmarshal.

Fixes #1133.

R=agl1
CC=golang-dev
https://golang.org/cl/2249045
parent 1d315a8a
...@@ -775,7 +775,7 @@ func setDefaultValue(v reflect.Value, params fieldParameters) (ok bool) { ...@@ -775,7 +775,7 @@ func setDefaultValue(v reflect.Value, params fieldParameters) (ok bool) {
// //
// Other ASN.1 types are not supported; if it encounters them, // Other ASN.1 types are not supported; if it encounters them,
// Unmarshal returns a parse error. // Unmarshal returns a parse error.
func Unmarshal(val interface{}, b []byte) (rest []byte, err os.Error) { func Unmarshal(b []byte, val interface{}) (rest []byte, err os.Error) {
v := reflect.NewValue(val).(*reflect.PtrValue).Elem() v := reflect.NewValue(val).(*reflect.PtrValue).Elem()
offset, err := parseField(v, b, 0, fieldParameters{}) offset, err := parseField(v, b, 0, fieldParameters{})
if err != nil { if err != nil {
......
...@@ -312,7 +312,7 @@ func TestUnmarshal(t *testing.T) { ...@@ -312,7 +312,7 @@ func TestUnmarshal(t *testing.T) {
zv := reflect.MakeZero(pv.Type().(*reflect.PtrType).Elem()) zv := reflect.MakeZero(pv.Type().(*reflect.PtrType).Elem())
pv.(*reflect.PtrValue).PointTo(zv) pv.(*reflect.PtrValue).PointTo(zv)
val := pv.Interface() val := pv.Interface()
_, err := Unmarshal(val, test.in) _, err := Unmarshal(test.in, val)
if err != nil { if err != nil {
t.Errorf("Unmarshal failed at index %d %v", i, err) t.Errorf("Unmarshal failed at index %d %v", i, err)
} }
...@@ -363,7 +363,7 @@ type PublicKeyInfo struct { ...@@ -363,7 +363,7 @@ type PublicKeyInfo struct {
func TestCertificate(t *testing.T) { func TestCertificate(t *testing.T) {
// This is a minimal, self-signed certificate that should parse correctly. // This is a minimal, self-signed certificate that should parse correctly.
var cert Certificate var cert Certificate
if _, err := Unmarshal(&cert, derEncodedSelfSignedCertBytes); err != nil { if _, err := Unmarshal(derEncodedSelfSignedCertBytes, &cert); err != nil {
t.Errorf("Unmarshal failed: %v", err) t.Errorf("Unmarshal failed: %v", err)
} }
if !reflect.DeepEqual(cert, derEncodedSelfSignedCert) { if !reflect.DeepEqual(cert, derEncodedSelfSignedCert) {
...@@ -376,7 +376,7 @@ func TestCertificateWithNUL(t *testing.T) { ...@@ -376,7 +376,7 @@ func TestCertificateWithNUL(t *testing.T) {
// NUL isn't a permitted character in a PrintableString. // NUL isn't a permitted character in a PrintableString.
var cert Certificate var cert Certificate
if _, err := Unmarshal(&cert, derEncodedPaypalNULCertBytes); err == nil { if _, err := Unmarshal(derEncodedPaypalNULCertBytes, &cert); err == nil {
t.Error("Unmarshal succeeded, should not have") t.Error("Unmarshal succeeded, should not have")
} }
} }
...@@ -390,7 +390,7 @@ func TestRawStructs(t *testing.T) { ...@@ -390,7 +390,7 @@ func TestRawStructs(t *testing.T) {
var s rawStructTest var s rawStructTest
input := []byte{0x30, 0x03, 0x02, 0x01, 0x50} input := []byte{0x30, 0x03, 0x02, 0x01, 0x50}
rest, err := Unmarshal(&s, input) rest, err := Unmarshal(input, &s)
if len(rest) != 0 { if len(rest) != 0 {
t.Errorf("incomplete parse: %x", rest) t.Errorf("incomplete parse: %x", rest)
return return
......
...@@ -468,25 +468,15 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) ...@@ -468,25 +468,15 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters)
return nil return nil
} }
// Marshal serialises val as an ASN.1 structure and writes the result to out. // Marshal returns the ASN.1 encoding of val.
// In the case of an error, no output is produced. func Marshal(val interface{}) ([]byte, os.Error) {
func Marshal(out io.Writer, val interface{}) os.Error { var out bytes.Buffer
v := reflect.NewValue(val) v := reflect.NewValue(val)
f := newForkableWriter() f := newForkableWriter()
err := marshalField(f, v, fieldParameters{}) err := marshalField(f, v, fieldParameters{})
if err != nil { if err != nil {
return err
}
_, err = f.writeTo(out)
return err
}
// MarshalToMemory performs the same actions as Marshal, but returns the result
// as a byte slice.
func MarshalToMemory(val interface{}) ([]byte, os.Error) {
var out bytes.Buffer
if err := Marshal(&out, val); err != nil {
return nil, err return nil, err
} }
_, err = f.writeTo(&out)
return out.Bytes(), nil return out.Bytes(), nil
} }
...@@ -88,14 +88,13 @@ var marshalTests = []marshalTest{ ...@@ -88,14 +88,13 @@ var marshalTests = []marshalTest{
func TestMarshal(t *testing.T) { func TestMarshal(t *testing.T) {
for i, test := range marshalTests { for i, test := range marshalTests {
buf := bytes.NewBuffer(nil) data, err := Marshal(test.in)
err := Marshal(buf, test.in)
if err != nil { if err != nil {
t.Errorf("#%d failed: %s", i, err) t.Errorf("#%d failed: %s", i, err)
} }
out, _ := hex.DecodeString(test.out) out, _ := hex.DecodeString(test.out)
if bytes.Compare(out, buf.Bytes()) != 0 { if bytes.Compare(out, data) != 0 {
t.Errorf("#%d got: %x want %x", i, buf.Bytes(), out) t.Errorf("#%d got: %x want %x", i, data, out)
} }
} }
} }
...@@ -36,7 +36,7 @@ func rawValueIsInteger(raw *asn1.RawValue) bool { ...@@ -36,7 +36,7 @@ func rawValueIsInteger(raw *asn1.RawValue) bool {
// ParsePKCS1PrivateKey returns an RSA private key from its ASN.1 PKCS#1 DER encoded form. // ParsePKCS1PrivateKey returns an RSA private key from its ASN.1 PKCS#1 DER encoded form.
func ParsePKCS1PrivateKey(der []byte) (key *rsa.PrivateKey, err os.Error) { func ParsePKCS1PrivateKey(der []byte) (key *rsa.PrivateKey, err os.Error) {
var priv pkcs1PrivateKey var priv pkcs1PrivateKey
rest, err := asn1.Unmarshal(&priv, der) rest, err := asn1.Unmarshal(der, &priv)
if len(rest) > 0 { if len(rest) > 0 {
err = asn1.SyntaxError{"trailing data"} err = asn1.SyntaxError{"trailing data"}
return return
...@@ -81,7 +81,7 @@ func MarshalPKCS1PrivateKey(key *rsa.PrivateKey) []byte { ...@@ -81,7 +81,7 @@ func MarshalPKCS1PrivateKey(key *rsa.PrivateKey) []byte {
Q: asn1.RawValue{Tag: 2, Bytes: key.Q.Bytes()}, Q: asn1.RawValue{Tag: 2, Bytes: key.Q.Bytes()},
} }
b, _ := asn1.MarshalToMemory(priv) b, _ := asn1.Marshal(priv)
return b return b
} }
...@@ -480,7 +480,7 @@ func parsePublicKey(algo PublicKeyAlgorithm, asn1Data []byte) (interface{}, os.E ...@@ -480,7 +480,7 @@ func parsePublicKey(algo PublicKeyAlgorithm, asn1Data []byte) (interface{}, os.E
switch algo { switch algo {
case RSA: case RSA:
p := new(rsaPublicKey) p := new(rsaPublicKey)
_, err := asn1.Unmarshal(p, asn1Data) _, err := asn1.Unmarshal(asn1Data, p)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -543,7 +543,7 @@ func parseCertificate(in *certificate) (*Certificate, os.Error) { ...@@ -543,7 +543,7 @@ func parseCertificate(in *certificate) (*Certificate, os.Error) {
case 15: case 15:
// RFC 5280, 4.2.1.3 // RFC 5280, 4.2.1.3
var usageBits asn1.BitString var usageBits asn1.BitString
_, err := asn1.Unmarshal(&usageBits, e.Value) _, err := asn1.Unmarshal(e.Value, &usageBits)
if err == nil { if err == nil {
var usage int var usage int
...@@ -558,7 +558,7 @@ func parseCertificate(in *certificate) (*Certificate, os.Error) { ...@@ -558,7 +558,7 @@ func parseCertificate(in *certificate) (*Certificate, os.Error) {
case 19: case 19:
// RFC 5280, 4.2.1.9 // RFC 5280, 4.2.1.9
var constriants basicConstraints var constriants basicConstraints
_, err := asn1.Unmarshal(&constriants, e.Value) _, err := asn1.Unmarshal(e.Value, &constriants)
if err == nil { if err == nil {
out.BasicConstraintsValid = true out.BasicConstraintsValid = true
...@@ -584,7 +584,7 @@ func parseCertificate(in *certificate) (*Certificate, os.Error) { ...@@ -584,7 +584,7 @@ func parseCertificate(in *certificate) (*Certificate, os.Error) {
// iPAddress [7] OCTET STRING, // iPAddress [7] OCTET STRING,
// registeredID [8] OBJECT IDENTIFIER } // registeredID [8] OBJECT IDENTIFIER }
var seq asn1.RawValue var seq asn1.RawValue
_, err := asn1.Unmarshal(&seq, e.Value) _, err := asn1.Unmarshal(e.Value, &seq)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -597,7 +597,7 @@ func parseCertificate(in *certificate) (*Certificate, os.Error) { ...@@ -597,7 +597,7 @@ func parseCertificate(in *certificate) (*Certificate, os.Error) {
rest := seq.Bytes rest := seq.Bytes
for len(rest) > 0 { for len(rest) > 0 {
var v asn1.RawValue var v asn1.RawValue
rest, err = asn1.Unmarshal(&v, rest) rest, err = asn1.Unmarshal(rest, &v)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -620,7 +620,7 @@ func parseCertificate(in *certificate) (*Certificate, os.Error) { ...@@ -620,7 +620,7 @@ func parseCertificate(in *certificate) (*Certificate, os.Error) {
case 35: case 35:
// RFC 5280, 4.2.1.1 // RFC 5280, 4.2.1.1
var a authKeyId var a authKeyId
_, err = asn1.Unmarshal(&a, e.Value) _, err = asn1.Unmarshal(e.Value, &a)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -630,7 +630,7 @@ func parseCertificate(in *certificate) (*Certificate, os.Error) { ...@@ -630,7 +630,7 @@ func parseCertificate(in *certificate) (*Certificate, os.Error) {
case 14: case 14:
// RFC 5280, 4.2.1.2 // RFC 5280, 4.2.1.2
var keyid []byte var keyid []byte
_, err = asn1.Unmarshal(&keyid, e.Value) _, err = asn1.Unmarshal(e.Value, &keyid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -650,7 +650,7 @@ func parseCertificate(in *certificate) (*Certificate, os.Error) { ...@@ -650,7 +650,7 @@ func parseCertificate(in *certificate) (*Certificate, os.Error) {
// ParseCertificate parses a single certificate from the given ASN.1 DER data. // ParseCertificate parses a single certificate from the given ASN.1 DER data.
func ParseCertificate(asn1Data []byte) (*Certificate, os.Error) { func ParseCertificate(asn1Data []byte) (*Certificate, os.Error) {
var cert certificate var cert certificate
rest, err := asn1.Unmarshal(&cert, asn1Data) rest, err := asn1.Unmarshal(asn1Data, &cert)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -669,7 +669,7 @@ func ParseCertificates(asn1Data []byte) ([]*Certificate, os.Error) { ...@@ -669,7 +669,7 @@ func ParseCertificates(asn1Data []byte) ([]*Certificate, os.Error) {
for len(asn1Data) > 0 { for len(asn1Data) > 0 {
cert := new(certificate) cert := new(certificate)
var err os.Error var err os.Error
asn1Data, err = asn1.Unmarshal(cert, asn1Data) asn1Data, err = asn1.Unmarshal(asn1Data, cert)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -720,7 +720,7 @@ func buildExtensions(template *Certificate) (ret []extension, err os.Error) { ...@@ -720,7 +720,7 @@ func buildExtensions(template *Certificate) (ret []extension, err os.Error) {
l = 2 l = 2
} }
ret[n].Value, err = asn1.MarshalToMemory(asn1.BitString{Bytes: a[0:l], BitLength: l * 8}) ret[n].Value, err = asn1.Marshal(asn1.BitString{Bytes: a[0:l], BitLength: l * 8})
if err != nil { if err != nil {
return return
} }
...@@ -729,7 +729,7 @@ func buildExtensions(template *Certificate) (ret []extension, err os.Error) { ...@@ -729,7 +729,7 @@ func buildExtensions(template *Certificate) (ret []extension, err os.Error) {
if template.BasicConstraintsValid { if template.BasicConstraintsValid {
ret[n].Id = oidExtensionBasicConstraints ret[n].Id = oidExtensionBasicConstraints
ret[n].Value, err = asn1.MarshalToMemory(basicConstraints{template.IsCA, template.MaxPathLen}) ret[n].Value, err = asn1.Marshal(basicConstraints{template.IsCA, template.MaxPathLen})
ret[n].Critical = true ret[n].Critical = true
if err != nil { if err != nil {
return return
...@@ -739,7 +739,7 @@ func buildExtensions(template *Certificate) (ret []extension, err os.Error) { ...@@ -739,7 +739,7 @@ func buildExtensions(template *Certificate) (ret []extension, err os.Error) {
if len(template.SubjectKeyId) > 0 { if len(template.SubjectKeyId) > 0 {
ret[n].Id = oidExtensionSubjectKeyId ret[n].Id = oidExtensionSubjectKeyId
ret[n].Value, err = asn1.MarshalToMemory(template.SubjectKeyId) ret[n].Value, err = asn1.Marshal(template.SubjectKeyId)
if err != nil { if err != nil {
return return
} }
...@@ -748,7 +748,7 @@ func buildExtensions(template *Certificate) (ret []extension, err os.Error) { ...@@ -748,7 +748,7 @@ func buildExtensions(template *Certificate) (ret []extension, err os.Error) {
if len(template.AuthorityKeyId) > 0 { if len(template.AuthorityKeyId) > 0 {
ret[n].Id = oidExtensionAuthorityKeyId ret[n].Id = oidExtensionAuthorityKeyId
ret[n].Value, err = asn1.MarshalToMemory(authKeyId{template.AuthorityKeyId}) ret[n].Value, err = asn1.Marshal(authKeyId{template.AuthorityKeyId})
if err != nil { if err != nil {
return return
} }
...@@ -761,7 +761,7 @@ func buildExtensions(template *Certificate) (ret []extension, err os.Error) { ...@@ -761,7 +761,7 @@ func buildExtensions(template *Certificate) (ret []extension, err os.Error) {
for i, name := range template.DNSNames { for i, name := range template.DNSNames {
rawValues[i] = asn1.RawValue{Tag: 2, Class: 2, Bytes: []byte(name)} rawValues[i] = asn1.RawValue{Tag: 2, Class: 2, Bytes: []byte(name)}
} }
ret[n].Value, err = asn1.MarshalToMemory(rawValues) ret[n].Value, err = asn1.Marshal(rawValues)
if err != nil { if err != nil {
return return
} }
...@@ -790,7 +790,7 @@ var ( ...@@ -790,7 +790,7 @@ var (
// //
// The returned slice is the certificate in DER encoding. // The returned slice is the certificate in DER encoding.
func CreateCertificate(rand io.Reader, template, parent *Certificate, pub *rsa.PublicKey, priv *rsa.PrivateKey) (cert []byte, err os.Error) { func CreateCertificate(rand io.Reader, template, parent *Certificate, pub *rsa.PublicKey, priv *rsa.PrivateKey) (cert []byte, err os.Error) {
asn1PublicKey, err := asn1.MarshalToMemory(rsaPublicKey{ asn1PublicKey, err := asn1.Marshal(rsaPublicKey{
N: asn1.RawValue{Tag: 2, Bytes: pub.N.Bytes()}, N: asn1.RawValue{Tag: 2, Bytes: pub.N.Bytes()},
E: pub.E, E: pub.E,
}) })
...@@ -819,7 +819,7 @@ func CreateCertificate(rand io.Reader, template, parent *Certificate, pub *rsa.P ...@@ -819,7 +819,7 @@ func CreateCertificate(rand io.Reader, template, parent *Certificate, pub *rsa.P
Extensions: extensions, Extensions: extensions,
} }
tbsCertContents, err := asn1.MarshalToMemory(c) tbsCertContents, err := asn1.Marshal(c)
if err != nil { if err != nil {
return return
} }
...@@ -835,7 +835,7 @@ func CreateCertificate(rand io.Reader, template, parent *Certificate, pub *rsa.P ...@@ -835,7 +835,7 @@ func CreateCertificate(rand io.Reader, template, parent *Certificate, pub *rsa.P
return return
} }
cert, err = asn1.MarshalToMemory(certificate{ cert, err = asn1.Marshal(certificate{
c, c,
algorithmIdentifier{oidSHA1WithRSA}, algorithmIdentifier{oidSHA1WithRSA},
asn1.BitString{Bytes: signature, BitLength: len(signature) * 8}, asn1.BitString{Bytes: signature, BitLength: len(signature) * 8},
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment