Commit cded21a3 authored by Russ Cox's avatar Russ Cox

changes for more restricted reflect.SetValue

R=golang-dev, r
CC=golang-dev
https://golang.org/cl/4423043
parent 40fccbce
...@@ -276,21 +276,21 @@ func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value) ...@@ -276,21 +276,21 @@ func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value)
return v return v
case reflect.Struct: case reflect.Struct:
v := reflect.Zero(p.Type()) v := reflect.New(p.Type()).Elem()
for i := 0; i < p.NumField(); i++ { for i := 0; i < p.NumField(); i++ {
v.Field(i).Set(subst(m, p.Field(i), pos)) v.Field(i).Set(subst(m, p.Field(i), pos))
} }
return v return v
case reflect.Ptr: case reflect.Ptr:
v := reflect.Zero(p.Type()) v := reflect.New(p.Type()).Elem()
if elem := p.Elem(); elem.IsValid() { if elem := p.Elem(); elem.IsValid() {
v.Set(subst(m, elem, pos).Addr()) v.Set(subst(m, elem, pos).Addr())
} }
return v return v
case reflect.Interface: case reflect.Interface:
v := reflect.Zero(p.Type()) v := reflect.New(p.Type()).Elem()
if elem := p.Elem(); elem.IsValid() { if elem := p.Elem(); elem.IsValid() {
v.Set(subst(m, elem, pos)) v.Set(subst(m, elem, pos))
} }
......
...@@ -267,11 +267,6 @@ func TestParseFieldParameters(t *testing.T) { ...@@ -267,11 +267,6 @@ func TestParseFieldParameters(t *testing.T) {
} }
} }
type unmarshalTest struct {
in []byte
out interface{}
}
type TestObjectIdentifierStruct struct { type TestObjectIdentifierStruct struct {
OID ObjectIdentifier OID ObjectIdentifier
} }
...@@ -290,7 +285,10 @@ type TestElementsAfterString struct { ...@@ -290,7 +285,10 @@ type TestElementsAfterString struct {
A, B int A, B int
} }
var unmarshalTestData []unmarshalTest = []unmarshalTest{ var unmarshalTestData = []struct {
in []byte
out interface{}
}{
{[]byte{0x02, 0x01, 0x42}, newInt(0x42)}, {[]byte{0x02, 0x01, 0x42}, newInt(0x42)},
{[]byte{0x30, 0x08, 0x06, 0x06, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d}, &TestObjectIdentifierStruct{[]int{1, 2, 840, 113549}}}, {[]byte{0x30, 0x08, 0x06, 0x06, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d}, &TestObjectIdentifierStruct{[]int{1, 2, 840, 113549}}},
{[]byte{0x03, 0x04, 0x06, 0x6e, 0x5d, 0xc0}, &BitString{[]byte{110, 93, 192}, 18}}, {[]byte{0x03, 0x04, 0x06, 0x6e, 0x5d, 0xc0}, &BitString{[]byte{110, 93, 192}, 18}},
...@@ -309,9 +307,7 @@ var unmarshalTestData []unmarshalTest = []unmarshalTest{ ...@@ -309,9 +307,7 @@ var unmarshalTestData []unmarshalTest = []unmarshalTest{
func TestUnmarshal(t *testing.T) { func TestUnmarshal(t *testing.T) {
for i, test := range unmarshalTestData { for i, test := range unmarshalTestData {
pv := reflect.Zero(reflect.NewValue(test.out).Type()) pv := reflect.New(reflect.Typeof(test.out).Elem())
zv := reflect.Zero(pv.Type().Elem())
pv.Set(zv.Addr())
val := pv.Interface() val := pv.Interface()
_, err := Unmarshal(test.in, val) _, err := Unmarshal(test.in, val)
if err != nil { if err != nil {
......
...@@ -999,7 +999,6 @@ type Bad0 struct { ...@@ -999,7 +999,6 @@ type Bad0 struct {
C float64 C float64
} }
func TestInvalidField(t *testing.T) { func TestInvalidField(t *testing.T) {
var bad0 Bad0 var bad0 Bad0
bad0.CH = make(chan int) bad0.CH = make(chan int)
......
...@@ -581,7 +581,7 @@ func (dec *Decoder) decodeArray(atyp reflect.Type, state *decoderState, p uintpt ...@@ -581,7 +581,7 @@ func (dec *Decoder) decodeArray(atyp reflect.Type, state *decoderState, p uintpt
// unlike the other items we can't use a pointer directly. // unlike the other items we can't use a pointer directly.
func decodeIntoValue(state *decoderState, op decOp, indir int, v reflect.Value, ovfl os.ErrorString) reflect.Value { func decodeIntoValue(state *decoderState, op decOp, indir int, v reflect.Value, ovfl os.ErrorString) reflect.Value {
instr := &decInstr{op, 0, indir, 0, ovfl} instr := &decInstr{op, 0, indir, 0, ovfl}
up := unsafe.Pointer(v.UnsafeAddr()) up := unsafe.Pointer(unsafeAddr(v))
if indir > 1 { if indir > 1 {
up = decIndirect(up, indir) up = decIndirect(up, indir)
} }
...@@ -608,8 +608,8 @@ func (dec *Decoder) decodeMap(mtyp reflect.Type, state *decoderState, p uintptr, ...@@ -608,8 +608,8 @@ func (dec *Decoder) decodeMap(mtyp reflect.Type, state *decoderState, p uintptr,
v := reflect.NewValue(unsafe.Unreflect(mtyp, unsafe.Pointer(p))) v := reflect.NewValue(unsafe.Unreflect(mtyp, unsafe.Pointer(p)))
n := int(state.decodeUint()) n := int(state.decodeUint())
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
key := decodeIntoValue(state, keyOp, keyIndir, reflect.Zero(mtyp.Key()), ovfl) key := decodeIntoValue(state, keyOp, keyIndir, allocValue(mtyp.Key()), ovfl)
elem := decodeIntoValue(state, elemOp, elemIndir, reflect.Zero(mtyp.Elem()), ovfl) elem := decodeIntoValue(state, elemOp, elemIndir, allocValue(mtyp.Elem()), ovfl)
v.SetMapIndex(key, elem) v.SetMapIndex(key, elem)
} }
} }
...@@ -686,8 +686,8 @@ func setInterfaceValue(ivalue reflect.Value, value reflect.Value) { ...@@ -686,8 +686,8 @@ func setInterfaceValue(ivalue reflect.Value, value reflect.Value) {
// Interfaces are encoded as the name of a concrete type followed by a value. // Interfaces are encoded as the name of a concrete type followed by a value.
// If the name is empty, the value is nil and no value is sent. // If the name is empty, the value is nil and no value is sent.
func (dec *Decoder) decodeInterface(ityp reflect.Type, state *decoderState, p uintptr, indir int) { func (dec *Decoder) decodeInterface(ityp reflect.Type, state *decoderState, p uintptr, indir int) {
// Create an interface reflect.Value. We need one even for the nil case. // Create a writable interface reflect.Value. We need one even for the nil case.
ivalue := reflect.Zero(ityp) ivalue := allocValue(ityp)
// Read the name of the concrete type. // Read the name of the concrete type.
b := make([]byte, state.decodeUint()) b := make([]byte, state.decodeUint())
state.b.Read(b) state.b.Read(b)
...@@ -712,7 +712,7 @@ func (dec *Decoder) decodeInterface(ityp reflect.Type, state *decoderState, p ui ...@@ -712,7 +712,7 @@ func (dec *Decoder) decodeInterface(ityp reflect.Type, state *decoderState, p ui
// in case we want to ignore the value by skipping it completely). // in case we want to ignore the value by skipping it completely).
state.decodeUint() state.decodeUint()
// Read the concrete value. // Read the concrete value.
value := reflect.Zero(typ) value := allocValue(typ)
dec.decodeValue(concreteId, value) dec.decodeValue(concreteId, value)
if dec.err != nil { if dec.err != nil {
error(dec.err) error(dec.err)
...@@ -1209,9 +1209,9 @@ func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) { ...@@ -1209,9 +1209,9 @@ func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) {
name := base.Name() name := base.Name()
errorf("gob: type mismatch: no fields matched compiling decoder for %s", name) errorf("gob: type mismatch: no fields matched compiling decoder for %s", name)
} }
dec.decodeStruct(engine, ut, uintptr(val.UnsafeAddr()), ut.indir) dec.decodeStruct(engine, ut, uintptr(unsafeAddr(val)), ut.indir)
} else { } else {
dec.decodeSingle(engine, ut, uintptr(val.UnsafeAddr())) dec.decodeSingle(engine, ut, uintptr(unsafeAddr(val)))
} }
} }
...@@ -1256,3 +1256,26 @@ func init() { ...@@ -1256,3 +1256,26 @@ func init() {
} }
decOpTable[reflect.Uintptr] = uop decOpTable[reflect.Uintptr] = uop
} }
// Gob assumes it can call UnsafeAddr on any Value
// in order to get a pointer it can copy data from.
// Values that have just been created and do not point
// into existing structs or slices cannot be addressed,
// so simulate it by returning a pointer to a copy.
// Each call allocates once.
func unsafeAddr(v reflect.Value) uintptr {
if v.CanAddr() {
return v.UnsafeAddr()
}
x := reflect.New(v.Type()).Elem()
x.Set(v)
return x.UnsafeAddr()
}
// Gob depends on being able to take the address
// of zeroed Values it creates, so use this wrapper instead
// of the standard reflect.Zero.
// Each call allocates once.
func allocValue(t reflect.Type) reflect.Value {
return reflect.New(t).Elem()
}
...@@ -171,12 +171,18 @@ func (dec *Decoder) Decode(e interface{}) os.Error { ...@@ -171,12 +171,18 @@ func (dec *Decoder) Decode(e interface{}) os.Error {
return dec.DecodeValue(value) return dec.DecodeValue(value)
} }
// DecodeValue reads the next value from the connection and stores // DecodeValue reads the next value from the connection.
// it in the data represented by the reflection value. // If v is the zero reflect.Value (v.Kind() == Invalid), DecodeValue discards the value.
// The value must be the correct type for the next // Otherwise, it stores the value into v. In that case, v must represent
// data item received, or it may be nil, which means the // a non-nil pointer to data or be an assignable reflect.Value (v.CanSet())
// value will be discarded. func (dec *Decoder) DecodeValue(v reflect.Value) os.Error {
func (dec *Decoder) DecodeValue(value reflect.Value) os.Error { if v.IsValid() {
if v.Kind() == reflect.Ptr && !v.IsNil() {
// That's okay, we'll store through the pointer.
} else if !v.CanSet() {
return os.ErrorString("gob: DecodeValue of unassignable value")
}
}
// Make sure we're single-threaded through here. // Make sure we're single-threaded through here.
dec.mutex.Lock() dec.mutex.Lock()
defer dec.mutex.Unlock() defer dec.mutex.Unlock()
...@@ -185,7 +191,7 @@ func (dec *Decoder) DecodeValue(value reflect.Value) os.Error { ...@@ -185,7 +191,7 @@ func (dec *Decoder) DecodeValue(value reflect.Value) os.Error {
dec.err = nil dec.err = nil
id := dec.decodeTypeSequence(false) id := dec.decodeTypeSequence(false)
if dec.err == nil { if dec.err == nil {
dec.decodeValue(id, value) dec.decodeValue(id, v)
} }
return dec.err return dec.err
} }
......
...@@ -402,7 +402,7 @@ func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir in ...@@ -402,7 +402,7 @@ func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir in
if !v.IsValid() { if !v.IsValid() {
errorf("gob: encodeReflectValue: nil element") errorf("gob: encodeReflectValue: nil element")
} }
op(nil, state, unsafe.Pointer(v.UnsafeAddr())) op(nil, state, unsafe.Pointer(unsafeAddr(v)))
} }
// encodeMap encodes a map as unsigned count followed by key:value pairs. // encodeMap encodes a map as unsigned count followed by key:value pairs.
...@@ -695,8 +695,8 @@ func (enc *Encoder) encode(b *bytes.Buffer, value reflect.Value, ut *userTypeInf ...@@ -695,8 +695,8 @@ func (enc *Encoder) encode(b *bytes.Buffer, value reflect.Value, ut *userTypeInf
value = reflect.Indirect(value) value = reflect.Indirect(value)
} }
if !ut.isGobEncoder && value.Type().Kind() == reflect.Struct { if !ut.isGobEncoder && value.Type().Kind() == reflect.Struct {
enc.encodeStruct(b, engine, value.UnsafeAddr()) enc.encodeStruct(b, engine, unsafeAddr(value))
} else { } else {
enc.encodeSingle(b, engine, value.UnsafeAddr()) enc.encodeSingle(b, engine, unsafeAddr(value))
} }
} }
...@@ -170,7 +170,7 @@ func TestTypeToPtrType(t *testing.T) { ...@@ -170,7 +170,7 @@ func TestTypeToPtrType(t *testing.T) {
A int A int
} }
t0 := Type0{7} t0 := Type0{7}
t0p := (*Type0)(nil) t0p := new(Type0)
if err := encAndDec(t0, t0p); err != nil { if err := encAndDec(t0, t0p); err != nil {
t.Error(err) t.Error(err)
} }
......
...@@ -124,8 +124,7 @@ func (d *decodeState) unmarshal(v interface{}) (err os.Error) { ...@@ -124,8 +124,7 @@ func (d *decodeState) unmarshal(v interface{}) (err os.Error) {
rv := reflect.NewValue(v) rv := reflect.NewValue(v)
pv := rv pv := rv
if pv.Kind() != reflect.Ptr || if pv.Kind() != reflect.Ptr || pv.IsNil() {
pv.IsNil() {
return &InvalidUnmarshalError{reflect.Typeof(v)} return &InvalidUnmarshalError{reflect.Typeof(v)}
} }
...@@ -267,17 +266,17 @@ func (d *decodeState) indirect(v reflect.Value, wantptr bool) (Unmarshaler, refl ...@@ -267,17 +266,17 @@ func (d *decodeState) indirect(v reflect.Value, wantptr bool) (Unmarshaler, refl
v = iv.Elem() v = iv.Elem()
continue continue
} }
pv := v pv := v
if pv.Kind() != reflect.Ptr { if pv.Kind() != reflect.Ptr {
break break
} }
if pv.Elem().Kind() != reflect.Ptr && if pv.Elem().Kind() != reflect.Ptr && wantptr && pv.CanSet() && !isUnmarshaler {
wantptr && !isUnmarshaler {
return nil, pv return nil, pv
} }
if pv.IsNil() { if pv.IsNil() {
pv.Set(reflect.Zero(pv.Type().Elem()).Addr()) pv.Set(reflect.New(pv.Type().Elem()))
} }
if isUnmarshaler { if isUnmarshaler {
// Using v.Interface().(Unmarshaler) // Using v.Interface().(Unmarshaler)
...@@ -443,6 +442,8 @@ func (d *decodeState) object(v reflect.Value) { ...@@ -443,6 +442,8 @@ func (d *decodeState) object(v reflect.Value) {
return return
} }
var mapElem reflect.Value
for { for {
// Read opening " of string key or closing }. // Read opening " of string key or closing }.
op := d.scanWhile(scanSkipSpace) op := d.scanWhile(scanSkipSpace)
...@@ -466,7 +467,13 @@ func (d *decodeState) object(v reflect.Value) { ...@@ -466,7 +467,13 @@ func (d *decodeState) object(v reflect.Value) {
// Figure out field corresponding to key. // Figure out field corresponding to key.
var subv reflect.Value var subv reflect.Value
if mv.IsValid() { if mv.IsValid() {
subv = reflect.Zero(mv.Type().Elem()) elemType := mv.Type().Elem()
if !mapElem.IsValid() {
mapElem = reflect.New(elemType).Elem()
} else {
mapElem.Set(reflect.Zero(elemType))
}
subv = mapElem
} else { } else {
var f reflect.StructField var f reflect.StructField
var ok bool var ok bool
......
...@@ -138,8 +138,7 @@ func TestUnmarshal(t *testing.T) { ...@@ -138,8 +138,7 @@ func TestUnmarshal(t *testing.T) {
continue continue
} }
// v = new(right-type) // v = new(right-type)
v := reflect.NewValue(tt.ptr) v := reflect.New(reflect.Typeof(tt.ptr).Elem())
v.Set(reflect.Zero(v.Type().Elem()).Addr())
if err := Unmarshal([]byte(in), v.Interface()); !reflect.DeepEqual(err, tt.err) { if err := Unmarshal([]byte(in), v.Interface()); !reflect.DeepEqual(err, tt.err) {
t.Errorf("#%d: %v want %v", i, err, tt.err) t.Errorf("#%d: %v want %v", i, err, tt.err)
continue continue
......
...@@ -221,7 +221,7 @@ func (client *expClient) serveSend(hdr header) { ...@@ -221,7 +221,7 @@ func (client *expClient) serveSend(hdr header) {
return return
} }
// Create a new value for each received item. // Create a new value for each received item.
val := reflect.Zero(nch.ch.Type().Elem()) val := reflect.New(nch.ch.Type().Elem()).Elem()
if err := client.decode(val); err != nil { if err := client.decode(val); err != nil {
expLog("value decode:", err, "; type ", nch.ch.Type()) expLog("value decode:", err, "; type ", nch.ch.Type())
return return
......
...@@ -133,7 +133,7 @@ func (imp *Importer) run() { ...@@ -133,7 +133,7 @@ func (imp *Importer) run() {
ackHdr.SeqNum = hdr.SeqNum ackHdr.SeqNum = hdr.SeqNum
imp.encode(ackHdr, payAck, nil) imp.encode(ackHdr, payAck, nil)
// Create a new value for each received item. // Create a new value for each received item.
value := reflect.Zero(nch.ch.Type().Elem()) value := reflect.New(nch.ch.Type().Elem()).Elem()
if e := imp.decode(value); e != nil { if e := imp.decode(value); e != nil {
impLog("importer value decode:", e) impLog("importer value decode:", e)
return return
......
...@@ -297,12 +297,6 @@ type InvalidRequest struct{} ...@@ -297,12 +297,6 @@ type InvalidRequest struct{}
var invalidRequest = InvalidRequest{} var invalidRequest = InvalidRequest{}
func _new(t reflect.Type) reflect.Value {
v := reflect.Zero(t)
v.Set(reflect.Zero(t.Elem()).Addr())
return v
}
func (server *Server) sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec ServerCodec, errmsg string) { func (server *Server) sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec ServerCodec, errmsg string) {
resp := server.getResponse() resp := server.getResponse()
// Encode the response header // Encode the response header
...@@ -411,8 +405,8 @@ func (server *Server) ServeCodec(codec ServerCodec) { ...@@ -411,8 +405,8 @@ func (server *Server) ServeCodec(codec ServerCodec) {
} }
// Decode the argument value. // Decode the argument value.
argv := _new(mtype.ArgType) argv := reflect.New(mtype.ArgType.Elem())
replyv := _new(mtype.ReplyType) replyv := reflect.New(mtype.ReplyType.Elem())
err = codec.ReadRequestBody(argv.Interface()) err = codec.ReadRequestBody(argv.Interface())
if err != nil { if err != nil {
if err == os.EOF || err == io.ErrUnexpectedEOF { if err == os.EOF || err == io.ErrUnexpectedEOF {
......
...@@ -107,8 +107,8 @@ func Value(t reflect.Type, rand *rand.Rand) (value reflect.Value, ok bool) { ...@@ -107,8 +107,8 @@ func Value(t reflect.Type, rand *rand.Rand) (value reflect.Value, ok bool) {
if !ok { if !ok {
return reflect.Value{}, false return reflect.Value{}, false
} }
p := reflect.Zero(concrete) p := reflect.New(concrete.Elem())
p.Set(v.Addr()) p.Elem().Set(v)
return p, true return p, true
case reflect.Slice: case reflect.Slice:
numElems := rand.Intn(complexSize) numElems := rand.Intn(complexSize)
...@@ -129,7 +129,7 @@ func Value(t reflect.Type, rand *rand.Rand) (value reflect.Value, ok bool) { ...@@ -129,7 +129,7 @@ func Value(t reflect.Type, rand *rand.Rand) (value reflect.Value, ok bool) {
} }
return reflect.NewValue(string(codePoints)), true return reflect.NewValue(string(codePoints)), true
case reflect.Struct: case reflect.Struct:
s := reflect.Zero(t) s := reflect.New(t).Elem()
for i := 0; i < s.NumField(); i++ { for i := 0; i < s.NumField(); i++ {
v, ok := Value(concrete.Field(i).Type, rand) v, ok := Value(concrete.Field(i).Type, rand)
if !ok { if !ok {
......
...@@ -288,9 +288,7 @@ var pathTests = []interface{}{ ...@@ -288,9 +288,7 @@ var pathTests = []interface{}{
func TestUnmarshalPaths(t *testing.T) { func TestUnmarshalPaths(t *testing.T) {
for _, pt := range pathTests { for _, pt := range pathTests {
p := reflect.Zero(reflect.NewValue(pt).Type()) v := reflect.New(reflect.Typeof(pt).Elem()).Interface()
p.Set(reflect.Zero(p.Type().Elem()).Addr())
v := p.Interface()
if err := Unmarshal(StringReader(pathTestString), v); err != nil { if err := Unmarshal(StringReader(pathTestString), v); err != nil {
t.Fatalf("Unmarshal: %s", err) t.Fatalf("Unmarshal: %s", err)
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment