Commit ee8e28d3 authored by Alex Brainman's avatar Alex Brainman

syscall: another attempt to keep windows syscall pointers live

This approach was suggested in
https://golang.org/cl/138250043/#msg15.
Unlike current version of mksyscall_windows.go,
new code could be used in go.sys and other external
repos without help from asm.

LGTM=iant
R=golang-codereviews, iant, r
CC=golang-codereviews
https://golang.org/cl/143160046
parent 0a6f8b04
...@@ -138,8 +138,6 @@ func (p *Param) StringTmpVarCode() string { ...@@ -138,8 +138,6 @@ func (p *Param) StringTmpVarCode() string {
// TmpVarCode returns source code for temp variable. // TmpVarCode returns source code for temp variable.
func (p *Param) TmpVarCode() string { func (p *Param) TmpVarCode() string {
switch { switch {
case p.Type == "string":
return p.StringTmpVarCode()
case p.Type == "bool": case p.Type == "bool":
return p.BoolTmpVarCode() return p.BoolTmpVarCode()
case strings.HasPrefix(p.Type, "[]"): case strings.HasPrefix(p.Type, "[]"):
...@@ -149,19 +147,26 @@ func (p *Param) TmpVarCode() string { ...@@ -149,19 +147,26 @@ func (p *Param) TmpVarCode() string {
} }
} }
// TmpVarHelperCode returns source code for helper's temp variable.
func (p *Param) TmpVarHelperCode() string {
if p.Type != "string" {
return ""
}
return p.StringTmpVarCode()
}
// SyscallArgList returns source code fragments representing p parameter // SyscallArgList returns source code fragments representing p parameter
// in syscall. Slices are translated into 2 syscall parameters: pointer to // in syscall. Slices are translated into 2 syscall parameters: pointer to
// the first element and length. // the first element and length.
func (p *Param) SyscallArgList() []string { func (p *Param) SyscallArgList() []string {
t := p.HelperType()
var s string var s string
switch { switch {
case p.Type[0] == '*': case t[0] == '*':
s = fmt.Sprintf("unsafe.Pointer(%s)", p.Name) s = fmt.Sprintf("unsafe.Pointer(%s)", p.Name)
case p.Type == "string": case t == "bool":
s = fmt.Sprintf("unsafe.Pointer(%s)", p.tmpVar())
case p.Type == "bool":
s = p.tmpVar() s = p.tmpVar()
case strings.HasPrefix(p.Type, "[]"): case strings.HasPrefix(t, "[]"):
return []string{ return []string{
fmt.Sprintf("uintptr(unsafe.Pointer(%s))", p.tmpVar()), fmt.Sprintf("uintptr(unsafe.Pointer(%s))", p.tmpVar()),
fmt.Sprintf("uintptr(len(%s))", p.Name), fmt.Sprintf("uintptr(len(%s))", p.Name),
...@@ -177,6 +182,14 @@ func (p *Param) IsError() bool { ...@@ -177,6 +182,14 @@ func (p *Param) IsError() bool {
return p.Name == "err" && p.Type == "error" return p.Name == "err" && p.Type == "error"
} }
// HelperType returns type of parameter p used in helper function.
func (p *Param) HelperType() string {
if p.Type == "string" {
return p.fn.StrconvType()
}
return p.Type
}
// join concatenates parameters ps into a string with sep separator. // join concatenates parameters ps into a string with sep separator.
// Each parameter is converted into string by applying fn to it // Each parameter is converted into string by applying fn to it
// before conversion. // before conversion.
...@@ -454,6 +467,11 @@ func (f *Fn) ParamList() string { ...@@ -454,6 +467,11 @@ func (f *Fn) ParamList() string {
return join(f.Params, func(p *Param) string { return p.Name + " " + p.Type }, ", ") return join(f.Params, func(p *Param) string { return p.Name + " " + p.Type }, ", ")
} }
// HelperParamList returns source code for helper function f parameters.
func (f *Fn) HelperParamList() string {
return join(f.Params, func(p *Param) string { return p.Name + " " + p.HelperType() }, ", ")
}
// ParamPrintList returns source code of trace printing part correspondent // ParamPrintList returns source code of trace printing part correspondent
// to syscall input parameters. // to syscall input parameters.
func (f *Fn) ParamPrintList() string { func (f *Fn) ParamPrintList() string {
...@@ -510,6 +528,19 @@ func (f *Fn) SyscallParamList() string { ...@@ -510,6 +528,19 @@ func (f *Fn) SyscallParamList() string {
return strings.Join(a, ", ") return strings.Join(a, ", ")
} }
// HelperCallParamList returns source code of call into function f helper.
func (f *Fn) HelperCallParamList() string {
a := make([]string, 0, len(f.Params))
for _, p := range f.Params {
s := p.Name
if p.Type == "string" {
s = p.tmpVar()
}
a = append(a, s)
}
return strings.Join(a, ", ")
}
// IsUTF16 is true, if f is W (utf16) function. It is false // IsUTF16 is true, if f is W (utf16) function. It is false
// for all A (ascii) functions. // for all A (ascii) functions.
func (f *Fn) IsUTF16() bool { func (f *Fn) IsUTF16() bool {
...@@ -533,6 +564,25 @@ func (f *Fn) StrconvType() string { ...@@ -533,6 +564,25 @@ func (f *Fn) StrconvType() string {
return "*byte" return "*byte"
} }
// HasStringParam is true, if f has at least one string parameter.
// Otherwise it is false.
func (f *Fn) HasStringParam() bool {
for _, p := range f.Params {
if p.Type == "string" {
return true
}
}
return false
}
// HelperName returns name of function f helper.
func (f *Fn) HelperName() string {
if !f.HasStringParam() {
return f.Name
}
return "_" + f.Name
}
// Source files and functions. // Source files and functions.
type Source struct { type Source struct {
Funcs []*Fn Funcs []*Fn
...@@ -666,7 +716,7 @@ import "syscall"{{end}} ...@@ -666,7 +716,7 @@ import "syscall"{{end}}
var ( var (
{{template "dlls" .}} {{template "dlls" .}}
{{template "funcnames" .}}) {{template "funcnames" .}})
{{range .Funcs}}{{template "funcbody" .}}{{end}} {{range .Funcs}}{{if .HasStringParam}}{{template "helperbody" .}}{{end}}{{template "funcbody" .}}{{end}}
{{end}} {{end}}
{{/* help functions */}} {{/* help functions */}}
...@@ -677,16 +727,27 @@ var ( ...@@ -677,16 +727,27 @@ var (
{{define "funcnames"}}{{range .Funcs}} proc{{.DLLFuncName}} = mod{{.DLLName}}.NewProc("{{.DLLFuncName}}") {{define "funcnames"}}{{range .Funcs}} proc{{.DLLFuncName}} = mod{{.DLLName}}.NewProc("{{.DLLFuncName}}")
{{end}}{{end}} {{end}}{{end}}
{{define "helperbody"}}
func {{.Name}}({{.ParamList}}) {{template "results" .}}{
{{template "helpertmpvars" .}} return {{.HelperName}}({{.HelperCallParamList}})
}
{{end}}
{{define "funcbody"}} {{define "funcbody"}}
func {{.Name}}({{.ParamList}}) {{if .Rets.List}}{{.Rets.List}} {{end}}{ func {{.HelperName}}({{.HelperParamList}}) {{template "results" .}}{
{{template "tmpvars" .}} {{template "syscall" .}} {{template "tmpvars" .}} {{template "syscall" .}}
{{template "seterror" .}}{{template "printtrace" .}} return {{template "seterror" .}}{{template "printtrace" .}} return
} }
{{end}} {{end}}
{{define "helpertmpvars"}}{{range .Params}}{{if .TmpVarHelperCode}} {{.TmpVarHelperCode}}
{{end}}{{end}}{{end}}
{{define "tmpvars"}}{{range .Params}}{{if .TmpVarCode}} {{.TmpVarCode}} {{define "tmpvars"}}{{range .Params}}{{if .TmpVarCode}} {{.TmpVarCode}}
{{end}}{{end}}{{end}} {{end}}{{end}}{{end}}
{{define "results"}}{{if .Rets.List}}{{.Rets.List}} {{end}}{{end}}
{{define "syscall"}}{{.Rets.SetReturnValuesCode}}{{.Syscall}}(proc{{.DLLFuncName}}.Addr(), {{.ParamCount}}, {{.SyscallParamList}}){{end}} {{define "syscall"}}{{.Rets.SetReturnValuesCode}}{{.Syscall}}(proc{{.DLLFuncName}}.Addr(), {{.ParamCount}}, {{.SyscallParamList}}){{end}}
{{define "seterror"}}{{if .Rets.SetErrorCode}} {{.Rets.SetErrorCode}} {{define "seterror"}}{{if .Rets.SetErrorCode}} {{.Rets.SetErrorCode}}
......
...@@ -176,7 +176,11 @@ func LoadLibrary(libname string) (handle Handle, err error) { ...@@ -176,7 +176,11 @@ func LoadLibrary(libname string) (handle Handle, err error) {
if err != nil { if err != nil {
return return
} }
r0, _, e1 := Syscall(procLoadLibraryW.Addr(), 1, uintptr(unsafe.Pointer(_p0)), 0, 0) return _LoadLibrary(_p0)
}
func _LoadLibrary(libname *uint16) (handle Handle, err error) {
r0, _, e1 := Syscall(procLoadLibraryW.Addr(), 1, uintptr(unsafe.Pointer(libname)), 0, 0)
handle = Handle(r0) handle = Handle(r0)
if handle == 0 { if handle == 0 {
if e1 != 0 { if e1 != 0 {
...@@ -206,7 +210,11 @@ func GetProcAddress(module Handle, procname string) (proc uintptr, err error) { ...@@ -206,7 +210,11 @@ func GetProcAddress(module Handle, procname string) (proc uintptr, err error) {
if err != nil { if err != nil {
return return
} }
r0, _, e1 := Syscall(procGetProcAddress.Addr(), 2, uintptr(module), uintptr(unsafe.Pointer(_p0)), 0) return _GetProcAddress(module, _p0)
}
func _GetProcAddress(module Handle, procname *byte) (proc uintptr, err error) {
r0, _, e1 := Syscall(procGetProcAddress.Addr(), 2, uintptr(module), uintptr(unsafe.Pointer(procname)), 0)
proc = uintptr(r0) proc = uintptr(r0)
if proc == 0 { if proc == 0 {
if e1 != 0 { if e1 != 0 {
...@@ -1558,7 +1566,11 @@ func GetHostByName(name string) (h *Hostent, err error) { ...@@ -1558,7 +1566,11 @@ func GetHostByName(name string) (h *Hostent, err error) {
if err != nil { if err != nil {
return return
} }
r0, _, e1 := Syscall(procgethostbyname.Addr(), 1, uintptr(unsafe.Pointer(_p0)), 0, 0) return _GetHostByName(_p0)
}
func _GetHostByName(name *byte) (h *Hostent, err error) {
r0, _, e1 := Syscall(procgethostbyname.Addr(), 1, uintptr(unsafe.Pointer(name)), 0, 0)
h = (*Hostent)(unsafe.Pointer(r0)) h = (*Hostent)(unsafe.Pointer(r0))
if h == nil { if h == nil {
if e1 != 0 { if e1 != 0 {
...@@ -1581,7 +1593,11 @@ func GetServByName(name string, proto string) (s *Servent, err error) { ...@@ -1581,7 +1593,11 @@ func GetServByName(name string, proto string) (s *Servent, err error) {
if err != nil { if err != nil {
return return
} }
r0, _, e1 := Syscall(procgetservbyname.Addr(), 2, uintptr(unsafe.Pointer(_p0)), uintptr(unsafe.Pointer(_p1)), 0) return _GetServByName(_p0, _p1)
}
func _GetServByName(name *byte, proto *byte) (s *Servent, err error) {
r0, _, e1 := Syscall(procgetservbyname.Addr(), 2, uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(proto)), 0)
s = (*Servent)(unsafe.Pointer(r0)) s = (*Servent)(unsafe.Pointer(r0))
if s == nil { if s == nil {
if e1 != 0 { if e1 != 0 {
...@@ -1605,7 +1621,11 @@ func GetProtoByName(name string) (p *Protoent, err error) { ...@@ -1605,7 +1621,11 @@ func GetProtoByName(name string) (p *Protoent, err error) {
if err != nil { if err != nil {
return return
} }
r0, _, e1 := Syscall(procgetprotobyname.Addr(), 1, uintptr(unsafe.Pointer(_p0)), 0, 0) return _GetProtoByName(_p0)
}
func _GetProtoByName(name *byte) (p *Protoent, err error) {
r0, _, e1 := Syscall(procgetprotobyname.Addr(), 1, uintptr(unsafe.Pointer(name)), 0, 0)
p = (*Protoent)(unsafe.Pointer(r0)) p = (*Protoent)(unsafe.Pointer(r0))
if p == nil { if p == nil {
if e1 != 0 { if e1 != 0 {
...@@ -1623,7 +1643,11 @@ func DnsQuery(name string, qtype uint16, options uint32, extra *byte, qrs **DNSR ...@@ -1623,7 +1643,11 @@ func DnsQuery(name string, qtype uint16, options uint32, extra *byte, qrs **DNSR
if status != nil { if status != nil {
return return
} }
r0, _, _ := Syscall6(procDnsQuery_W.Addr(), 6, uintptr(unsafe.Pointer(_p0)), uintptr(qtype), uintptr(options), uintptr(unsafe.Pointer(extra)), uintptr(unsafe.Pointer(qrs)), uintptr(unsafe.Pointer(pr))) return _DnsQuery(_p0, qtype, options, extra, qrs, pr)
}
func _DnsQuery(name *uint16, qtype uint16, options uint32, extra *byte, qrs **DNSRecord, pr *byte) (status error) {
r0, _, _ := Syscall6(procDnsQuery_W.Addr(), 6, uintptr(unsafe.Pointer(name)), uintptr(qtype), uintptr(options), uintptr(unsafe.Pointer(extra)), uintptr(unsafe.Pointer(qrs)), uintptr(unsafe.Pointer(pr)))
if r0 != 0 { if r0 != 0 {
status = Errno(r0) status = Errno(r0)
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment