Commit fae362e9 authored by Alexey Borzenkov's avatar Alexey Borzenkov Committed by Alex Brainman

os/user: faster user lookup on Windows

Trying to lookup user's display name with directory services can
take several seconds when user's computer is not in a domain.
As a workaround, check if computer is joined in a domain first,
and don't use directory services if it is not.
Additionally, don't leak tokens in user.Current().
Fixes #5298.

R=golang-dev, bradfitz, alex.brainman, lucio.dere
CC=golang-dev
https://golang.org/cl/8541047
parent 72b14cbb
...@@ -10,13 +10,24 @@ import ( ...@@ -10,13 +10,24 @@ import (
"unsafe" "unsafe"
) )
func lookupFullName(domain, username, domainAndUser string) (string, error) { func isDomainJoined() (bool, error) {
// try domain controller first var domain *uint16
name, e := syscall.TranslateAccountName(domainAndUser, var status uint32
err := syscall.NetGetJoinInformation(nil, &domain, &status)
if err != nil {
return false, err
}
syscall.NetApiBufferFree((*byte)(unsafe.Pointer(domain)))
return status == syscall.NetSetupDomainName, nil
}
func lookupFullNameDomain(domainAndUser string) (string, error) {
return syscall.TranslateAccountName(domainAndUser,
syscall.NameSamCompatible, syscall.NameDisplay, 50) syscall.NameSamCompatible, syscall.NameDisplay, 50)
if e != nil { }
// domain lookup failed, perhaps this pc is not part of domain
d, e := syscall.UTF16PtrFromString(domain) func lookupFullNameServer(servername, username string) (string, error) {
s, e := syscall.UTF16PtrFromString(servername)
if e != nil { if e != nil {
return "", e return "", e
} }
...@@ -25,20 +36,35 @@ func lookupFullName(domain, username, domainAndUser string) (string, error) { ...@@ -25,20 +36,35 @@ func lookupFullName(domain, username, domainAndUser string) (string, error) {
return "", e return "", e
} }
var p *byte var p *byte
e = syscall.NetUserGetInfo(d, u, 10, &p) e = syscall.NetUserGetInfo(s, u, 10, &p)
if e != nil { if e != nil {
// path executed when a domain user is disconnected from the domain return "", e
// pretend username is fullname
return username, nil
} }
defer syscall.NetApiBufferFree(p) defer syscall.NetApiBufferFree(p)
i := (*syscall.UserInfo10)(unsafe.Pointer(p)) i := (*syscall.UserInfo10)(unsafe.Pointer(p))
if i.FullName == nil { if i.FullName == nil {
return "", nil return "", nil
} }
name = syscall.UTF16ToString((*[1024]uint16)(unsafe.Pointer(i.FullName))[:]) name := syscall.UTF16ToString((*[1024]uint16)(unsafe.Pointer(i.FullName))[:])
return name, nil
}
func lookupFullName(domain, username, domainAndUser string) (string, error) {
joined, err := isDomainJoined()
if err == nil && joined {
name, err := lookupFullNameDomain(domainAndUser)
if err == nil {
return name, nil
} }
}
name, err := lookupFullNameServer(domain, username)
if err == nil {
return name, nil return name, nil
}
// domain worked neigher as a domain nor as a server
// could be domain server unavailable
// pretend username is fullname
return username, nil
} }
func newUser(usid *syscall.SID, gid, dir string) (*User, error) { func newUser(usid *syscall.SID, gid, dir string) (*User, error) {
...@@ -73,6 +99,7 @@ func current() (*User, error) { ...@@ -73,6 +99,7 @@ func current() (*User, error) {
if e != nil { if e != nil {
return nil, e return nil, e
} }
defer t.Close()
u, e := t.GetTokenUser() u, e := t.GetTokenUser()
if e != nil { if e != nil {
return nil, e return nil, e
......
...@@ -58,6 +58,14 @@ func TranslateAccountName(username string, from, to uint32, initSize int) (strin ...@@ -58,6 +58,14 @@ func TranslateAccountName(username string, from, to uint32, initSize int) (strin
return UTF16ToString(b), nil return UTF16ToString(b), nil
} }
const (
// do not reorder
NetSetupUnknownStatus = iota
NetSetupUnjoined
NetSetupWorkgroupName
NetSetupDomainName
)
type UserInfo10 struct { type UserInfo10 struct {
Name *uint16 Name *uint16
Comment *uint16 Comment *uint16
...@@ -66,6 +74,7 @@ type UserInfo10 struct { ...@@ -66,6 +74,7 @@ type UserInfo10 struct {
} }
//sys NetUserGetInfo(serverName *uint16, userName *uint16, level uint32, buf **byte) (neterr error) = netapi32.NetUserGetInfo //sys NetUserGetInfo(serverName *uint16, userName *uint16, level uint32, buf **byte) (neterr error) = netapi32.NetUserGetInfo
//sys NetGetJoinInformation(server *uint16, name **uint16, bufType *uint32) (neterr error) = netapi32.NetGetJoinInformation
//sys NetApiBufferFree(buf *byte) (neterr error) = netapi32.NetApiBufferFree //sys NetApiBufferFree(buf *byte) (neterr error) = netapi32.NetApiBufferFree
const ( const (
......
...@@ -140,6 +140,7 @@ var ( ...@@ -140,6 +140,7 @@ var (
procTranslateNameW = modsecur32.NewProc("TranslateNameW") procTranslateNameW = modsecur32.NewProc("TranslateNameW")
procGetUserNameExW = modsecur32.NewProc("GetUserNameExW") procGetUserNameExW = modsecur32.NewProc("GetUserNameExW")
procNetUserGetInfo = modnetapi32.NewProc("NetUserGetInfo") procNetUserGetInfo = modnetapi32.NewProc("NetUserGetInfo")
procNetGetJoinInformation = modnetapi32.NewProc("NetGetJoinInformation")
procNetApiBufferFree = modnetapi32.NewProc("NetApiBufferFree") procNetApiBufferFree = modnetapi32.NewProc("NetApiBufferFree")
procLookupAccountSidW = modadvapi32.NewProc("LookupAccountSidW") procLookupAccountSidW = modadvapi32.NewProc("LookupAccountSidW")
procLookupAccountNameW = modadvapi32.NewProc("LookupAccountNameW") procLookupAccountNameW = modadvapi32.NewProc("LookupAccountNameW")
...@@ -1613,6 +1614,14 @@ func NetUserGetInfo(serverName *uint16, userName *uint16, level uint32, buf **by ...@@ -1613,6 +1614,14 @@ func NetUserGetInfo(serverName *uint16, userName *uint16, level uint32, buf **by
return return
} }
func NetGetJoinInformation(server *uint16, name **uint16, bufType *uint32) (neterr error) {
r0, _, _ := Syscall(procNetGetJoinInformation.Addr(), 3, uintptr(unsafe.Pointer(server)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(bufType)))
if r0 != 0 {
neterr = Errno(r0)
}
return
}
func NetApiBufferFree(buf *byte) (neterr error) { func NetApiBufferFree(buf *byte) (neterr error) {
r0, _, _ := Syscall(procNetApiBufferFree.Addr(), 1, uintptr(unsafe.Pointer(buf)), 0, 0) r0, _, _ := Syscall(procNetApiBufferFree.Addr(), 1, uintptr(unsafe.Pointer(buf)), 0, 0)
if r0 != 0 { if r0 != 0 {
......
...@@ -140,6 +140,7 @@ var ( ...@@ -140,6 +140,7 @@ var (
procTranslateNameW = modsecur32.NewProc("TranslateNameW") procTranslateNameW = modsecur32.NewProc("TranslateNameW")
procGetUserNameExW = modsecur32.NewProc("GetUserNameExW") procGetUserNameExW = modsecur32.NewProc("GetUserNameExW")
procNetUserGetInfo = modnetapi32.NewProc("NetUserGetInfo") procNetUserGetInfo = modnetapi32.NewProc("NetUserGetInfo")
procNetGetJoinInformation = modnetapi32.NewProc("NetGetJoinInformation")
procNetApiBufferFree = modnetapi32.NewProc("NetApiBufferFree") procNetApiBufferFree = modnetapi32.NewProc("NetApiBufferFree")
procLookupAccountSidW = modadvapi32.NewProc("LookupAccountSidW") procLookupAccountSidW = modadvapi32.NewProc("LookupAccountSidW")
procLookupAccountNameW = modadvapi32.NewProc("LookupAccountNameW") procLookupAccountNameW = modadvapi32.NewProc("LookupAccountNameW")
...@@ -1613,6 +1614,14 @@ func NetUserGetInfo(serverName *uint16, userName *uint16, level uint32, buf **by ...@@ -1613,6 +1614,14 @@ func NetUserGetInfo(serverName *uint16, userName *uint16, level uint32, buf **by
return return
} }
func NetGetJoinInformation(server *uint16, name **uint16, bufType *uint32) (neterr error) {
r0, _, _ := Syscall(procNetGetJoinInformation.Addr(), 3, uintptr(unsafe.Pointer(server)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(bufType)))
if r0 != 0 {
neterr = Errno(r0)
}
return
}
func NetApiBufferFree(buf *byte) (neterr error) { func NetApiBufferFree(buf *byte) (neterr error) {
r0, _, _ := Syscall(procNetApiBufferFree.Addr(), 1, uintptr(unsafe.Pointer(buf)), 0, 0) r0, _, _ := Syscall(procNetApiBufferFree.Addr(), 1, uintptr(unsafe.Pointer(buf)), 0, 0)
if r0 != 0 { if r0 != 0 {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment