Commit fae362e9 authored by Alexey Borzenkov's avatar Alexey Borzenkov Committed by Alex Brainman

os/user: faster user lookup on Windows

Trying to lookup user's display name with directory services can
take several seconds when user's computer is not in a domain.
As a workaround, check if computer is joined in a domain first,
and don't use directory services if it is not.
Additionally, don't leak tokens in user.Current().
Fixes #5298.

R=golang-dev, bradfitz, alex.brainman, lucio.dere
CC=golang-dev
https://golang.org/cl/8541047
parent 72b14cbb
......@@ -10,13 +10,24 @@ import (
"unsafe"
)
func lookupFullName(domain, username, domainAndUser string) (string, error) {
// try domain controller first
name, e := syscall.TranslateAccountName(domainAndUser,
func isDomainJoined() (bool, error) {
var domain *uint16
var status uint32
err := syscall.NetGetJoinInformation(nil, &domain, &status)
if err != nil {
return false, err
}
syscall.NetApiBufferFree((*byte)(unsafe.Pointer(domain)))
return status == syscall.NetSetupDomainName, nil
}
func lookupFullNameDomain(domainAndUser string) (string, error) {
return syscall.TranslateAccountName(domainAndUser,
syscall.NameSamCompatible, syscall.NameDisplay, 50)
if e != nil {
// domain lookup failed, perhaps this pc is not part of domain
d, e := syscall.UTF16PtrFromString(domain)
}
func lookupFullNameServer(servername, username string) (string, error) {
s, e := syscall.UTF16PtrFromString(servername)
if e != nil {
return "", e
}
......@@ -25,20 +36,35 @@ func lookupFullName(domain, username, domainAndUser string) (string, error) {
return "", e
}
var p *byte
e = syscall.NetUserGetInfo(d, u, 10, &p)
e = syscall.NetUserGetInfo(s, u, 10, &p)
if e != nil {
// path executed when a domain user is disconnected from the domain
// pretend username is fullname
return username, nil
return "", e
}
defer syscall.NetApiBufferFree(p)
i := (*syscall.UserInfo10)(unsafe.Pointer(p))
if i.FullName == nil {
return "", nil
}
name = syscall.UTF16ToString((*[1024]uint16)(unsafe.Pointer(i.FullName))[:])
name := syscall.UTF16ToString((*[1024]uint16)(unsafe.Pointer(i.FullName))[:])
return name, nil
}
func lookupFullName(domain, username, domainAndUser string) (string, error) {
joined, err := isDomainJoined()
if err == nil && joined {
name, err := lookupFullNameDomain(domainAndUser)
if err == nil {
return name, nil
}
}
name, err := lookupFullNameServer(domain, username)
if err == nil {
return name, nil
}
// domain worked neigher as a domain nor as a server
// could be domain server unavailable
// pretend username is fullname
return username, nil
}
func newUser(usid *syscall.SID, gid, dir string) (*User, error) {
......@@ -73,6 +99,7 @@ func current() (*User, error) {
if e != nil {
return nil, e
}
defer t.Close()
u, e := t.GetTokenUser()
if e != nil {
return nil, e
......
......@@ -58,6 +58,14 @@ func TranslateAccountName(username string, from, to uint32, initSize int) (strin
return UTF16ToString(b), nil
}
const (
// do not reorder
NetSetupUnknownStatus = iota
NetSetupUnjoined
NetSetupWorkgroupName
NetSetupDomainName
)
type UserInfo10 struct {
Name *uint16
Comment *uint16
......@@ -66,6 +74,7 @@ type UserInfo10 struct {
}
//sys NetUserGetInfo(serverName *uint16, userName *uint16, level uint32, buf **byte) (neterr error) = netapi32.NetUserGetInfo
//sys NetGetJoinInformation(server *uint16, name **uint16, bufType *uint32) (neterr error) = netapi32.NetGetJoinInformation
//sys NetApiBufferFree(buf *byte) (neterr error) = netapi32.NetApiBufferFree
const (
......
......@@ -140,6 +140,7 @@ var (
procTranslateNameW = modsecur32.NewProc("TranslateNameW")
procGetUserNameExW = modsecur32.NewProc("GetUserNameExW")
procNetUserGetInfo = modnetapi32.NewProc("NetUserGetInfo")
procNetGetJoinInformation = modnetapi32.NewProc("NetGetJoinInformation")
procNetApiBufferFree = modnetapi32.NewProc("NetApiBufferFree")
procLookupAccountSidW = modadvapi32.NewProc("LookupAccountSidW")
procLookupAccountNameW = modadvapi32.NewProc("LookupAccountNameW")
......@@ -1613,6 +1614,14 @@ func NetUserGetInfo(serverName *uint16, userName *uint16, level uint32, buf **by
return
}
func NetGetJoinInformation(server *uint16, name **uint16, bufType *uint32) (neterr error) {
r0, _, _ := Syscall(procNetGetJoinInformation.Addr(), 3, uintptr(unsafe.Pointer(server)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(bufType)))
if r0 != 0 {
neterr = Errno(r0)
}
return
}
func NetApiBufferFree(buf *byte) (neterr error) {
r0, _, _ := Syscall(procNetApiBufferFree.Addr(), 1, uintptr(unsafe.Pointer(buf)), 0, 0)
if r0 != 0 {
......
......@@ -140,6 +140,7 @@ var (
procTranslateNameW = modsecur32.NewProc("TranslateNameW")
procGetUserNameExW = modsecur32.NewProc("GetUserNameExW")
procNetUserGetInfo = modnetapi32.NewProc("NetUserGetInfo")
procNetGetJoinInformation = modnetapi32.NewProc("NetGetJoinInformation")
procNetApiBufferFree = modnetapi32.NewProc("NetApiBufferFree")
procLookupAccountSidW = modadvapi32.NewProc("LookupAccountSidW")
procLookupAccountNameW = modadvapi32.NewProc("LookupAccountNameW")
......@@ -1613,6 +1614,14 @@ func NetUserGetInfo(serverName *uint16, userName *uint16, level uint32, buf **by
return
}
func NetGetJoinInformation(server *uint16, name **uint16, bufType *uint32) (neterr error) {
r0, _, _ := Syscall(procNetGetJoinInformation.Addr(), 3, uintptr(unsafe.Pointer(server)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(bufType)))
if r0 != 0 {
neterr = Errno(r0)
}
return
}
func NetApiBufferFree(buf *byte) (neterr error) {
r0, _, _ := Syscall(procNetApiBufferFree.Addr(), 1, uintptr(unsafe.Pointer(buf)), 0, 0)
if r0 != 0 {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment