Commit 7b48020c authored by Ian Lance Taylor's avatar Ian Lance Taylor

cmd/cgo: check pointers for deferred C calls at the right time

We used to check time at the point of the defer statement. This change
fixes cgo to check them when the deferred function is executed.

Fixes #15921.

Change-Id: I72a10e26373cad6ad092773e9ebec4add29b9561
Reviewed-on: https://go-review.googlesource.com/23650
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarAustin Clements <austin@google.com>
parent 55559f15
...@@ -314,6 +314,14 @@ var ptrTests = []ptrTest{ ...@@ -314,6 +314,14 @@ var ptrTests = []ptrTest{
body: `i := 0; p := S{u:uintptr(unsafe.Pointer(&i))}; q := (*S)(C.malloc(C.size_t(unsafe.Sizeof(p)))); *q = p; C.f(unsafe.Pointer(q))`, body: `i := 0; p := S{u:uintptr(unsafe.Pointer(&i))}; q := (*S)(C.malloc(C.size_t(unsafe.Sizeof(p)))); *q = p; C.f(unsafe.Pointer(q))`,
fail: false, fail: false,
}, },
{
// Check deferred pointers when they are used, not
// when the defer statement is run.
name: "defer",
c: `typedef struct s { int *p; } s; void f(s *ps) {}`,
body: `p := &C.s{}; defer C.f(p); p.p = new(C.int)`,
fail: true,
},
} }
func main() { func main() {
......
...@@ -172,7 +172,7 @@ func (f *File) saveExprs(x interface{}, context string) { ...@@ -172,7 +172,7 @@ func (f *File) saveExprs(x interface{}, context string) {
f.saveRef(x, context) f.saveRef(x, context)
} }
case *ast.CallExpr: case *ast.CallExpr:
f.saveCall(x) f.saveCall(x, context)
} }
} }
...@@ -220,7 +220,7 @@ func (f *File) saveRef(n *ast.Expr, context string) { ...@@ -220,7 +220,7 @@ func (f *File) saveRef(n *ast.Expr, context string) {
} }
// Save calls to C.xxx for later processing. // Save calls to C.xxx for later processing.
func (f *File) saveCall(call *ast.CallExpr) { func (f *File) saveCall(call *ast.CallExpr, context string) {
sel, ok := call.Fun.(*ast.SelectorExpr) sel, ok := call.Fun.(*ast.SelectorExpr)
if !ok { if !ok {
return return
...@@ -228,7 +228,8 @@ func (f *File) saveCall(call *ast.CallExpr) { ...@@ -228,7 +228,8 @@ func (f *File) saveCall(call *ast.CallExpr) {
if l, ok := sel.X.(*ast.Ident); !ok || l.Name != "C" { if l, ok := sel.X.(*ast.Ident); !ok || l.Name != "C" {
return return
} }
f.Calls = append(f.Calls, call) c := &Call{Call: call, Deferred: context == "defer"}
f.Calls = append(f.Calls, c)
} }
// If a function should be exported add it to ExpFunc. // If a function should be exported add it to ExpFunc.
...@@ -401,7 +402,7 @@ func (f *File) walk(x interface{}, context string, visit func(*File, interface{} ...@@ -401,7 +402,7 @@ func (f *File) walk(x interface{}, context string, visit func(*File, interface{}
case *ast.GoStmt: case *ast.GoStmt:
f.walk(n.Call, "expr", visit) f.walk(n.Call, "expr", visit)
case *ast.DeferStmt: case *ast.DeferStmt:
f.walk(n.Call, "expr", visit) f.walk(n.Call, "defer", visit)
case *ast.ReturnStmt: case *ast.ReturnStmt:
f.walk(n.Results, "expr", visit) f.walk(n.Results, "expr", visit)
case *ast.BranchStmt: case *ast.BranchStmt:
......
...@@ -581,7 +581,7 @@ func (p *Package) mangleName(n *Name) { ...@@ -581,7 +581,7 @@ func (p *Package) mangleName(n *Name) {
func (p *Package) rewriteCalls(f *File) { func (p *Package) rewriteCalls(f *File) {
for _, call := range f.Calls { for _, call := range f.Calls {
// This is a call to C.xxx; set goname to "xxx". // This is a call to C.xxx; set goname to "xxx".
goname := call.Fun.(*ast.SelectorExpr).Sel.Name goname := call.Call.Fun.(*ast.SelectorExpr).Sel.Name
if goname == "malloc" { if goname == "malloc" {
continue continue
} }
...@@ -596,37 +596,58 @@ func (p *Package) rewriteCalls(f *File) { ...@@ -596,37 +596,58 @@ func (p *Package) rewriteCalls(f *File) {
// rewriteCall rewrites one call to add pointer checks. We replace // rewriteCall rewrites one call to add pointer checks. We replace
// each pointer argument x with _cgoCheckPointer(x).(T). // each pointer argument x with _cgoCheckPointer(x).(T).
func (p *Package) rewriteCall(f *File, call *ast.CallExpr, name *Name) { func (p *Package) rewriteCall(f *File, call *Call, name *Name) {
any := false
for i, param := range name.FuncType.Params { for i, param := range name.FuncType.Params {
if len(call.Args) <= i { if len(call.Call.Args) <= i {
// Avoid a crash; this will be caught when the // Avoid a crash; this will be caught when the
// generated file is compiled. // generated file is compiled.
return return
} }
if p.needsPointerCheck(f, param.Go, call.Call.Args[i]) {
any = true
break
}
}
if !any {
return
}
// An untyped nil does not need a pointer check, and // We need to rewrite this call.
// when _cgoCheckPointer returns the untyped nil the //
// type assertion we are going to insert will fail. // We are going to rewrite C.f(p) to C.f(_cgoCheckPointer(p)).
// Easier to just skip nil arguments. // If the call to C.f is deferred, that will check p at the
// TODO: Note that this fails if nil is shadowed. // point of the defer statement, not when the function is called, so
if id, ok := call.Args[i].(*ast.Ident); ok && id.Name == "nil" { // rewrite to func(_cgo0 ptype) { C.f(_cgoCheckPointer(_cgo0)) }(p)
continue
var dargs []ast.Expr
if call.Deferred {
dargs = make([]ast.Expr, len(name.FuncType.Params))
}
for i, param := range name.FuncType.Params {
origArg := call.Call.Args[i]
darg := origArg
if call.Deferred {
dargs[i] = darg
darg = ast.NewIdent(fmt.Sprintf("_cgo%d", i))
call.Call.Args[i] = darg
} }
if !p.needsPointerCheck(f, param.Go) { if !p.needsPointerCheck(f, param.Go, origArg) {
continue continue
} }
c := &ast.CallExpr{ c := &ast.CallExpr{
Fun: ast.NewIdent("_cgoCheckPointer"), Fun: ast.NewIdent("_cgoCheckPointer"),
Args: []ast.Expr{ Args: []ast.Expr{
call.Args[i], darg,
}, },
} }
// Add optional additional arguments for an address // Add optional additional arguments for an address
// expression. // expression.
c.Args = p.checkAddrArgs(f, c.Args, call.Args[i]) c.Args = p.checkAddrArgs(f, c.Args, origArg)
// _cgoCheckPointer returns interface{}. // _cgoCheckPointer returns interface{}.
// We need to type assert that to the type we want. // We need to type assert that to the type we want.
...@@ -664,14 +685,73 @@ func (p *Package) rewriteCall(f *File, call *ast.CallExpr, name *Name) { ...@@ -664,14 +685,73 @@ func (p *Package) rewriteCall(f *File, call *ast.CallExpr, name *Name) {
} }
} }
call.Args[i] = arg call.Call.Args[i] = arg
}
if call.Deferred {
params := make([]*ast.Field, len(name.FuncType.Params))
for i, param := range name.FuncType.Params {
ptype := param.Go
if p.hasUnsafePointer(ptype) {
// Avoid generating unsafe.Pointer by using
// interface{}. This works because we are
// going to call a _cgoCheckPointer function
// anyhow.
ptype = &ast.InterfaceType{
Methods: &ast.FieldList{},
}
}
params[i] = &ast.Field{
Names: []*ast.Ident{
ast.NewIdent(fmt.Sprintf("_cgo%d", i)),
},
Type: ptype,
}
}
dbody := &ast.CallExpr{
Fun: call.Call.Fun,
Args: call.Call.Args,
}
call.Call.Fun = &ast.FuncLit{
Type: &ast.FuncType{
Params: &ast.FieldList{
List: params,
},
},
Body: &ast.BlockStmt{
List: []ast.Stmt{
&ast.ExprStmt{
X: dbody,
},
},
},
}
call.Call.Args = dargs
call.Call.Lparen = token.NoPos
call.Call.Rparen = token.NoPos
// There is a Ref pointing to the old call.Call.Fun.
for _, ref := range f.Ref {
if ref.Expr == &call.Call.Fun {
ref.Expr = &dbody.Fun
}
}
} }
} }
// needsPointerCheck returns whether the type t needs a pointer check. // needsPointerCheck returns whether the type t needs a pointer check.
// This is true if t is a pointer and if the value to which it points // This is true if t is a pointer and if the value to which it points
// might contain a pointer. // might contain a pointer.
func (p *Package) needsPointerCheck(f *File, t ast.Expr) bool { func (p *Package) needsPointerCheck(f *File, t ast.Expr, arg ast.Expr) bool {
// An untyped nil does not need a pointer check, and when
// _cgoCheckPointer returns the untyped nil the type assertion we
// are going to insert will fail. Easier to just skip nil arguments.
// TODO: Note that this fails if nil is shadowed.
if id, ok := arg.(*ast.Ident); ok && id.Name == "nil" {
return false
}
return p.hasPointer(f, t, true) return p.hasPointer(f, t, true)
} }
......
...@@ -52,7 +52,7 @@ type File struct { ...@@ -52,7 +52,7 @@ type File struct {
Package string // Package name Package string // Package name
Preamble string // C preamble (doc comment on import "C") Preamble string // C preamble (doc comment on import "C")
Ref []*Ref // all references to C.xxx in AST Ref []*Ref // all references to C.xxx in AST
Calls []*ast.CallExpr // all calls to C.xxx in AST Calls []*Call // all calls to C.xxx in AST
ExpFunc []*ExpFunc // exported functions for this file ExpFunc []*ExpFunc // exported functions for this file
Name map[string]*Name // map from Go name to Name Name map[string]*Name // map from Go name to Name
} }
...@@ -66,6 +66,12 @@ func nameKeys(m map[string]*Name) []string { ...@@ -66,6 +66,12 @@ func nameKeys(m map[string]*Name) []string {
return ks return ks
} }
// A Call refers to a call of a C.xxx function in the AST.
type Call struct {
Call *ast.CallExpr
Deferred bool
}
// A Ref refers to an expression of the form C.xxx in the AST. // A Ref refers to an expression of the form C.xxx in the AST.
type Ref struct { type Ref struct {
Name *Name Name *Name
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment