Commit 42f07ff2 authored by Ross Light's avatar Ross Light Committed by Brad Fitzpatrick

os/user: add LookupGroup, LookupGroupId, and User.GroupIds functions

As part of local testing with a large group member list, I discovered
that the lookup functions don't resize their buffer if they receive
ERANGE.  I fixed this as a side-effect of this CL.

Thanks to @andrenth for the original CL.

Fixes #2617

Change-Id: Ie6aae2fe0a89eae5cce85786869a8acaa665ffe9
Reviewed-on: https://go-review.googlesource.com/19235Reviewed-by: default avatarIan Lance Taylor <iant@golang.org>
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
parent ff555f11
...@@ -12,11 +12,28 @@ func Current() (*User, error) { ...@@ -12,11 +12,28 @@ func Current() (*User, error) {
// Lookup looks up a user by username. If the user cannot be found, the // Lookup looks up a user by username. If the user cannot be found, the
// returned error is of type UnknownUserError. // returned error is of type UnknownUserError.
func Lookup(username string) (*User, error) { func Lookup(username string) (*User, error) {
return lookup(username) return lookupUser(username)
} }
// LookupId looks up a user by userid. If the user cannot be found, the // LookupId looks up a user by userid. If the user cannot be found, the
// returned error is of type UnknownUserIdError. // returned error is of type UnknownUserIdError.
func LookupId(uid string) (*User, error) { func LookupId(uid string) (*User, error) {
return lookupId(uid) return lookupUserId(uid)
}
// LookupGroup looks up a group by name. If the group cannot be found, the
// returned error is of type UnknownGroupError.
func LookupGroup(name string) (*Group, error) {
return lookupGroup(name)
}
// LookupGroupId looks up a group by groupid. If the group cannot be found, the
// returned error is of type UnknownGroupIdError.
func LookupGroupId(gid string) (*Group, error) {
return lookupGroupId(gid)
}
// GroupIds returns the list of group IDs that the user is a member of.
func (u *User) GroupIds() ([]string, error) {
return listGroups(u)
} }
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build android
package user
import "errors"
func init() {
userImplemented = false
groupImplemented = false
}
func current() (*User, error) {
return nil, errors.New("user: Current not implemented on android")
}
func lookupUser(string) (*User, error) {
return nil, errors.New("user: Lookup not implemented on android")
}
func lookupUserId(string) (*User, error) {
return nil, errors.New("user: LookupId not implemented on android")
}
func lookupGroup(string) (*Group, error) {
return nil, errors.New("user: LookupGroup not implemented on android")
}
func lookupGroupId(string) (*Group, error) {
return nil, errors.New("user: LookupGroupId not implemented on android")
}
func listGroups(*User) ([]string, error) {
return nil, errors.New("user: GroupIds not implemented on android")
}
...@@ -18,6 +18,10 @@ const ( ...@@ -18,6 +18,10 @@ const (
userFile = "/dev/user" userFile = "/dev/user"
) )
func init() {
groupImplemented = false
}
func current() (*User, error) { func current() (*User, error) {
ubytes, err := ioutil.ReadFile(userFile) ubytes, err := ioutil.ReadFile(userFile)
if err != nil { if err != nil {
...@@ -37,10 +41,22 @@ func current() (*User, error) { ...@@ -37,10 +41,22 @@ func current() (*User, error) {
return u, nil return u, nil
} }
func lookup(username string) (*User, error) { func lookupUser(username string) (*User, error) {
return nil, syscall.EPLAN9
}
func lookupUserId(uid string) (*User, error) {
return nil, syscall.EPLAN9
}
func lookupGroup(groupname string) (*Group, error) {
return nil, syscall.EPLAN9
}
func lookupGroupId(string) (*Group, error) {
return nil, syscall.EPLAN9 return nil, syscall.EPLAN9
} }
func lookupId(uid string) (*User, error) { func listGroups(*User) ([]string, error) {
return nil, syscall.EPLAN9 return nil, syscall.EPLAN9
} }
...@@ -2,27 +2,37 @@ ...@@ -2,27 +2,37 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// +build !cgo,!windows,!plan9 android // +build !cgo,!windows,!plan9,!android
package user package user
import ( import "errors"
"fmt"
"runtime"
)
func init() { func init() {
implemented = false userImplemented = false
groupImplemented = false
} }
func current() (*User, error) { func current() (*User, error) {
return nil, fmt.Errorf("user: Current not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) return nil, errors.New("user: Current requires cgo")
} }
func lookup(username string) (*User, error) { func lookupUser(username string) (*User, error) {
return nil, fmt.Errorf("user: Lookup not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) return nil, errors.New("user: Lookup requires cgo")
} }
func lookupId(uid string) (*User, error) { func lookupUserId(uid string) (*User, error) {
return nil, fmt.Errorf("user: LookupId not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) return nil, errors.New("user: LookupId requires cgo")
}
func lookupGroup(groupname string) (*Group, error) {
return nil, errors.New("user: LookupGroup requires cgo")
}
func lookupGroupId(string) (*Group, error) {
return nil, errors.New("user: LookupGroupId requires cgo")
}
func listGroups(*User) ([]string, error) {
return nil, errors.New("user: GroupIds requires cgo")
} }
...@@ -20,6 +20,7 @@ import ( ...@@ -20,6 +20,7 @@ import (
#include <unistd.h> #include <unistd.h>
#include <sys/types.h> #include <sys/types.h>
#include <pwd.h> #include <pwd.h>
#include <grp.h>
#include <stdlib.h> #include <stdlib.h>
static int mygetpwuid_r(int uid, struct passwd *pwd, static int mygetpwuid_r(int uid, struct passwd *pwd,
...@@ -31,76 +32,119 @@ static int mygetpwnam_r(const char *name, struct passwd *pwd, ...@@ -31,76 +32,119 @@ static int mygetpwnam_r(const char *name, struct passwd *pwd,
char *buf, size_t buflen, struct passwd **result) { char *buf, size_t buflen, struct passwd **result) {
return getpwnam_r(name, pwd, buf, buflen, result); return getpwnam_r(name, pwd, buf, buflen, result);
} }
*/
import "C"
func current() (*User, error) { static int mygetgrgid_r(int gid, struct group *grp,
return lookupUnix(syscall.Getuid(), "", false) char *buf, size_t buflen, struct group **result) {
return getgrgid_r(gid, grp, buf, buflen, result);
} }
func lookup(username string) (*User, error) { static int mygetgrouplist(const char *user, gid_t group, gid_t *groups,
return lookupUnix(-1, username, true) int *ngroups) {
return getgrouplist(user, group, groups, ngroups);
} }
*/
import "C"
func lookupId(uid string) (*User, error) { func current() (*User, error) {
i, e := strconv.Atoi(uid) return lookupUnixUid(syscall.Getuid())
if e != nil {
return nil, e
}
return lookupUnix(i, "", false)
} }
func lookupUnix(uid int, username string, lookupByName bool) (*User, error) { func lookupUser(username string) (*User, error) {
var pwd C.struct_passwd var pwd C.struct_passwd
var result *C.struct_passwd var result *C.struct_passwd
bufSize := C.sysconf(C._SC_GETPW_R_SIZE_MAX)
if bufSize == -1 {
// DragonFly and FreeBSD do not have _SC_GETPW_R_SIZE_MAX.
// Additionally, not all Linux systems have it, either. For
// example, the musl libc returns -1.
bufSize = 1024
}
if bufSize <= 0 || bufSize > 1<<20 {
return nil, fmt.Errorf("user: unreasonable _SC_GETPW_R_SIZE_MAX of %d", bufSize)
}
buf := C.malloc(C.size_t(bufSize))
defer C.free(buf)
var rv C.int
if lookupByName {
nameC := C.CString(username) nameC := C.CString(username)
defer C.free(unsafe.Pointer(nameC)) defer C.free(unsafe.Pointer(nameC))
buf := alloc(userBuffer)
defer buf.free()
err := retryWithBuffer(buf, func() syscall.Errno {
// mygetpwnam_r is a wrapper around getpwnam_r to avoid // mygetpwnam_r is a wrapper around getpwnam_r to avoid
// passing a size_t to getpwnam_r, because for unknown // passing a size_t to getpwnam_r, because for unknown
// reasons passing a size_t to getpwnam_r doesn't work on // reasons passing a size_t to getpwnam_r doesn't work on
// Solaris. // Solaris.
rv = C.mygetpwnam_r(nameC, return syscall.Errno(C.mygetpwnam_r(nameC,
&pwd, &pwd,
(*C.char)(buf), (*C.char)(buf.ptr),
C.size_t(bufSize), C.size_t(buf.size),
&result) &result))
if rv != 0 { })
return nil, fmt.Errorf("user: lookup username %s: %s", username, syscall.Errno(rv)) if err != nil {
return nil, fmt.Errorf("user: lookup username %s: %v", username, err)
} }
if result == nil { if result == nil {
return nil, UnknownUserError(username) return nil, UnknownUserError(username)
} }
} else { return buildUser(&pwd), err
}
func lookupUserId(uid string) (*User, error) {
i, e := strconv.Atoi(uid)
if e != nil {
return nil, e
}
return lookupUnixUid(i)
}
func lookupUnixUid(uid int) (*User, error) {
var pwd C.struct_passwd
var result *C.struct_passwd
buf := alloc(userBuffer)
defer buf.free()
err := retryWithBuffer(buf, func() syscall.Errno {
// mygetpwuid_r is a wrapper around getpwuid_r to // mygetpwuid_r is a wrapper around getpwuid_r to
// to avoid using uid_t because C.uid_t(uid) for // to avoid using uid_t because C.uid_t(uid) for
// unknown reasons doesn't work on linux. // unknown reasons doesn't work on linux.
rv = C.mygetpwuid_r(C.int(uid), return syscall.Errno(C.mygetpwuid_r(C.int(uid),
&pwd, &pwd,
(*C.char)(buf), (*C.char)(buf.ptr),
C.size_t(bufSize), C.size_t(buf.size),
&result) &result))
if rv != 0 { })
return nil, fmt.Errorf("user: lookup userid %d: %s", uid, syscall.Errno(rv)) if err != nil {
return nil, fmt.Errorf("user: lookup userid %d: %v", uid, err)
} }
if result == nil { if result == nil {
return nil, UnknownUserIdError(uid) return nil, UnknownUserIdError(uid)
} }
return buildUser(&pwd), nil
}
func listGroups(u *User) ([]string, error) {
ug, err := strconv.Atoi(u.Gid)
if err != nil {
return nil, fmt.Errorf("user: list groups for %s: invalid gid %q", u.Username, u.Gid)
}
userGID := C.gid_t(ug)
nameC := C.CString(u.Username)
defer C.free(unsafe.Pointer(nameC))
n := C.int(256)
gidsC := make([]C.gid_t, n)
rv := C.mygetgrouplist(nameC, userGID, &gidsC[0], &n)
if rv == -1 {
// More than initial buffer, but now n contains the correct size.
const maxGroups = 2048
if n > maxGroups {
return nil, fmt.Errorf("user: list groups for %s: member of more than %d groups", u.Username, maxGroups)
} }
gidsC = make([]C.gid_t, n)
rv := C.mygetgrouplist(nameC, userGID, &gidsC[0], &n)
if rv == -1 {
return nil, fmt.Errorf("user: list groups for %s failed (changed groups?)", u.Username)
}
}
gidsC = gidsC[:n]
gids := make([]string, 0, n)
for _, g := range gidsC[:n] {
gids = append(gids, strconv.Itoa(int(g)))
}
return gids, nil
}
func buildUser(pwd *C.struct_passwd) *User {
u := &User{ u := &User{
Uid: strconv.Itoa(int(pwd.pw_uid)), Uid: strconv.Itoa(int(pwd.pw_uid)),
Gid: strconv.Itoa(int(pwd.pw_gid)), Gid: strconv.Itoa(int(pwd.pw_gid)),
...@@ -115,5 +159,145 @@ func lookupUnix(uid int, username string, lookupByName bool) (*User, error) { ...@@ -115,5 +159,145 @@ func lookupUnix(uid int, username string, lookupByName bool) (*User, error) {
if i := strings.Index(u.Name, ","); i >= 0 { if i := strings.Index(u.Name, ","); i >= 0 {
u.Name = u.Name[:i] u.Name = u.Name[:i]
} }
return u, nil return u
}
func currentGroup() (*Group, error) {
return lookupUnixGid(syscall.Getgid())
}
func lookupGroup(groupname string) (*Group, error) {
var grp C.struct_group
var result *C.struct_group
buf := alloc(groupBuffer)
defer buf.free()
cname := C.CString(groupname)
defer C.free(unsafe.Pointer(cname))
err := retryWithBuffer(buf, func() syscall.Errno {
return syscall.Errno(C.getgrnam_r(cname,
&grp,
(*C.char)(buf.ptr),
C.size_t(buf.size),
&result))
})
if err != nil {
return nil, fmt.Errorf("user: lookup groupname %s: %v", groupname, err)
}
if result == nil {
return nil, UnknownGroupError(groupname)
}
return buildGroup(&grp), nil
}
func lookupGroupId(gid string) (*Group, error) {
i, e := strconv.Atoi(gid)
if e != nil {
return nil, e
}
return lookupUnixGid(i)
}
func lookupUnixGid(gid int) (*Group, error) {
var grp C.struct_group
var result *C.struct_group
buf := alloc(groupBuffer)
defer buf.free()
err := retryWithBuffer(buf, func() syscall.Errno {
// mygetgrgid_r is a wrapper around getgrgid_r to
// to avoid using gid_t because C.gid_t(gid) for
// unknown reasons doesn't work on linux.
return syscall.Errno(C.mygetgrgid_r(C.int(gid),
&grp,
(*C.char)(buf.ptr),
C.size_t(buf.size),
&result))
})
if err != nil {
return nil, fmt.Errorf("user: lookup groupid %d: %v", gid, err)
}
if result == nil {
return nil, UnknownGroupIdError(gid)
}
return buildGroup(&grp), nil
}
func buildGroup(grp *C.struct_group) *Group {
g := &Group{
Gid: strconv.Itoa(int(grp.gr_gid)),
Name: C.GoString(grp.gr_name),
}
return g
}
type bufferKind C.int
const (
userBuffer = bufferKind(C._SC_GETPW_R_SIZE_MAX)
groupBuffer = bufferKind(C._SC_GETGR_R_SIZE_MAX)
)
func (k bufferKind) initialSize() C.size_t {
sz := C.sysconf(C.int(k))
if sz == -1 {
// DragonFly and FreeBSD do not have _SC_GETPW_R_SIZE_MAX.
// Additionally, not all Linux systems have it, either. For
// example, the musl libc returns -1.
return 1024
}
if !isSizeReasonable(int64(sz)) {
// Truncate. If this truly isn't enough, retryWithBuffer will error on the first run.
return maxBufferSize
}
return C.size_t(sz)
}
type memBuffer struct {
ptr unsafe.Pointer
size C.size_t
}
func alloc(kind bufferKind) *memBuffer {
sz := kind.initialSize()
return &memBuffer{
ptr: C.malloc(sz),
size: sz,
}
}
func (mb *memBuffer) resize(newSize C.size_t) {
mb.ptr = C.realloc(mb.ptr, newSize)
mb.size = newSize
}
func (mb *memBuffer) free() {
C.free(mb.ptr)
}
// retryWithBuffer repeatedly calls f(), increasing the size of the
// buffer each time, until f succeeds, fails with a non-ERANGE error,
// or the buffer exceeds a reasonable limit.
func retryWithBuffer(buf *memBuffer, f func() syscall.Errno) error {
for {
errno := f()
if errno == 0 {
return nil
} else if errno != syscall.ERANGE {
return errno
}
newSize := buf.size * 2
if !isSizeReasonable(int64(newSize)) {
return fmt.Errorf("internal buffer exceeds %d bytes", maxBufferSize)
}
buf.resize(newSize)
}
}
const maxBufferSize = 1 << 20
func isSizeReasonable(sz int64) bool {
return sz > 0 && sz <= maxBufferSize
} }
...@@ -5,11 +5,16 @@ ...@@ -5,11 +5,16 @@
package user package user
import ( import (
"errors"
"fmt" "fmt"
"syscall" "syscall"
"unsafe" "unsafe"
) )
func init() {
groupImplemented = false
}
func isDomainJoined() (bool, error) { func isDomainJoined() (bool, error) {
var domain *uint16 var domain *uint16
var status uint32 var status uint32
...@@ -129,7 +134,7 @@ func newUserFromSid(usid *syscall.SID) (*User, error) { ...@@ -129,7 +134,7 @@ func newUserFromSid(usid *syscall.SID) (*User, error) {
return newUser(usid, gid, dir) return newUser(usid, gid, dir)
} }
func lookup(username string) (*User, error) { func lookupUser(username string) (*User, error) {
sid, _, t, e := syscall.LookupSID("", username) sid, _, t, e := syscall.LookupSID("", username)
if e != nil { if e != nil {
return nil, e return nil, e
...@@ -140,10 +145,22 @@ func lookup(username string) (*User, error) { ...@@ -140,10 +145,22 @@ func lookup(username string) (*User, error) {
return newUserFromSid(sid) return newUserFromSid(sid)
} }
func lookupId(uid string) (*User, error) { func lookupUserId(uid string) (*User, error) {
sid, e := syscall.StringToSid(uid) sid, e := syscall.StringToSid(uid)
if e != nil { if e != nil {
return nil, e return nil, e
} }
return newUserFromSid(sid) return newUserFromSid(sid)
} }
func lookupGroup(groupname string) (*Group, error) {
return nil, errors.New("user: LookupGroup not implemented on windows")
}
func lookupGroupId(string) (*Group, error) {
return nil, errors.New("user: LookupGroupId not implemented on windows")
}
func listGroups(*User) ([]string, error) {
return nil, errors.New("user: GroupIds not implemented on windows")
}
...@@ -9,23 +9,35 @@ import ( ...@@ -9,23 +9,35 @@ import (
"strconv" "strconv"
) )
var implemented = true // set to false by lookup_stubs.go's init var (
userImplemented = true // set to false by lookup_stubs.go's init
groupImplemented = true // set to false by lookup_stubs.go's init
)
// User represents a user account. // User represents a user account.
// //
// On posix systems Uid and Gid contain a decimal number // On POSIX systems Uid and Gid contain a decimal number
// representing uid and gid. On windows Uid and Gid // representing uid and gid. On windows Uid and Gid
// contain security identifier (SID) in a string format. // contain security identifier (SID) in a string format.
// On Plan 9, Uid, Gid, Username, and Name will be the // On Plan 9, Uid, Gid, Username, and Name will be the
// contents of /dev/user. // contents of /dev/user.
type User struct { type User struct {
Uid string // user id Uid string // user ID
Gid string // primary group id Gid string // primary group ID
Username string Username string
Name string Name string
HomeDir string HomeDir string
} }
// Group represents a grouping of users.
//
// On POSIX systems Gid contains a decimal number
// representing the group ID.
type Group struct {
Gid string // group ID
Name string // group name
}
// UnknownUserIdError is returned by LookupId when // UnknownUserIdError is returned by LookupId when
// a user cannot be found. // a user cannot be found.
type UnknownUserIdError int type UnknownUserIdError int
...@@ -41,3 +53,19 @@ type UnknownUserError string ...@@ -41,3 +53,19 @@ type UnknownUserError string
func (e UnknownUserError) Error() string { func (e UnknownUserError) Error() string {
return "user: unknown user " + string(e) return "user: unknown user " + string(e)
} }
// UnknownGroupIdError is returned by LookupGroupId when
// a group cannot be found.
type UnknownGroupIdError string
func (e UnknownGroupIdError) Error() string {
return "group: unknown groupid " + string(e)
}
// UnknownGroupError is returned by LookupGroup when
// a group cannot be found.
type UnknownGroupError string
func (e UnknownGroupError) Error() string {
return "group: unknown group " + string(e)
}
...@@ -9,14 +9,14 @@ import ( ...@@ -9,14 +9,14 @@ import (
"testing" "testing"
) )
func check(t *testing.T) { func checkUser(t *testing.T) {
if !implemented { if !userImplemented {
t.Skip("user: not implemented; skipping tests") t.Skip("user: not implemented; skipping tests")
} }
} }
func TestCurrent(t *testing.T) { func TestCurrent(t *testing.T) {
check(t) checkUser(t)
u, err := Current() u, err := Current()
if err != nil { if err != nil {
...@@ -53,7 +53,7 @@ func compare(t *testing.T, want, got *User) { ...@@ -53,7 +53,7 @@ func compare(t *testing.T, want, got *User) {
} }
func TestLookup(t *testing.T) { func TestLookup(t *testing.T) {
check(t) checkUser(t)
if runtime.GOOS == "plan9" { if runtime.GOOS == "plan9" {
t.Skipf("Lookup not implemented on %q", runtime.GOOS) t.Skipf("Lookup not implemented on %q", runtime.GOOS)
...@@ -71,7 +71,7 @@ func TestLookup(t *testing.T) { ...@@ -71,7 +71,7 @@ func TestLookup(t *testing.T) {
} }
func TestLookupId(t *testing.T) { func TestLookupId(t *testing.T) {
check(t) checkUser(t)
if runtime.GOOS == "plan9" { if runtime.GOOS == "plan9" {
t.Skipf("LookupId not implemented on %q", runtime.GOOS) t.Skipf("LookupId not implemented on %q", runtime.GOOS)
...@@ -87,3 +87,57 @@ func TestLookupId(t *testing.T) { ...@@ -87,3 +87,57 @@ func TestLookupId(t *testing.T) {
} }
compare(t, want, got) compare(t, want, got)
} }
func checkGroup(t *testing.T) {
if !groupImplemented {
t.Skip("user: group not implemented; skipping test")
}
}
func TestLookupGroup(t *testing.T) {
checkGroup(t)
user, err := Current()
if err != nil {
t.Fatalf("Current(): %v", err)
}
g1, err := LookupGroupId(user.Gid)
if err != nil {
t.Fatalf("LookupGroupId(%q): %v", user.Gid, err)
}
if g1.Gid != user.Gid {
t.Errorf("LookupGroupId(%q).Gid = %s; want %s", user.Gid, g1.Gid, user.Gid)
}
g2, err := LookupGroup(g1.Name)
if err != nil {
t.Fatalf("LookupGroup(%q): %v", g1.Name, err)
}
if g1.Gid != g2.Gid || g1.Name != g2.Name {
t.Errorf("LookupGroup(%q) = %+v; want %+v", g1.Name, g2, g1)
}
}
func TestGroupIds(t *testing.T) {
checkGroup(t)
user, err := Current()
if err != nil {
t.Fatalf("Current(): %v", err)
}
gids, err := user.GroupIds()
if err != nil {
t.Fatalf("%+v.GroupIds(): %v", user, err)
}
if !containsID(gids, user.Gid) {
t.Errorf("%+v.GroupIds() = %v; does not contain user GID %s", user, gids, user.Gid)
}
}
func containsID(ids []string, id string) bool {
for _, x := range ids {
if x == id {
return true
}
}
return false
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment