Commit bc999933 authored by Alex Brainman's avatar Alex Brainman

net: prevent races during windows lookup calls

This only affects code (with exception of lookupProtocol)
that is only executed on older versions of Windows.

R=rsc, bradfitz
CC=golang-dev
https://golang.org/cl/7293043
parent 97916f11
......@@ -6,26 +6,17 @@ package net
import (
"os"
"sync"
"runtime"
"syscall"
"unsafe"
)
var (
protoentLock sync.Mutex
hostentLock sync.Mutex
serventLock sync.Mutex
)
var (
lookupPort = oldLookupPort
lookupIP = oldLookupIP
)
// lookupProtocol looks up IP protocol name and returns correspondent protocol number.
func lookupProtocol(name string) (proto int, err error) {
protoentLock.Lock()
defer protoentLock.Unlock()
func getprotobyname(name string) (proto int, err error) {
p, err := syscall.GetProtoByName(name)
if err != nil {
return 0, os.NewSyscallError("GetProtoByName", err)
......@@ -33,6 +24,25 @@ func lookupProtocol(name string) (proto int, err error) {
return int(p.Proto), nil
}
// lookupProtocol looks up IP protocol name and returns correspondent protocol number.
func lookupProtocol(name string) (proto int, err error) {
// GetProtoByName return value is stored in thread local storage.
// Start new os thread before the call to prevent races.
type result struct {
proto int
err error
}
ch := make(chan result)
go func() {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
proto, err := getprotobyname(name)
ch <- result{proto: proto, err: err}
}()
r := <-ch
return r.proto, r.err
}
func lookupHost(name string) (addrs []string, err error) {
ips, err := LookupIP(name)
if err != nil {
......@@ -45,9 +55,7 @@ func lookupHost(name string) (addrs []string, err error) {
return
}
func oldLookupIP(name string) (addrs []IP, err error) {
hostentLock.Lock()
defer hostentLock.Unlock()
func gethostbyname(name string) (addrs []IP, err error) {
h, err := syscall.GetHostByName(name)
if err != nil {
return nil, os.NewSyscallError("GetHostByName", err)
......@@ -66,6 +74,24 @@ func oldLookupIP(name string) (addrs []IP, err error) {
return addrs, nil
}
func oldLookupIP(name string) (addrs []IP, err error) {
// GetHostByName return value is stored in thread local storage.
// Start new os thread before the call to prevent races.
type result struct {
addrs []IP
err error
}
ch := make(chan result)
go func() {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
addrs, err := gethostbyname(name)
ch <- result{addrs: addrs, err: err}
}()
r := <-ch
return r.addrs, r.err
}
func newLookupIP(name string) (addrs []IP, err error) {
hints := syscall.AddrinfoW{
Family: syscall.AF_UNSPEC,
......@@ -95,15 +121,13 @@ func newLookupIP(name string) (addrs []IP, err error) {
return addrs, nil
}
func oldLookupPort(network, service string) (port int, err error) {
func getservbyname(network, service string) (port int, err error) {
switch network {
case "tcp4", "tcp6":
network = "tcp"
case "udp4", "udp6":
network = "udp"
}
serventLock.Lock()
defer serventLock.Unlock()
s, err := syscall.GetServByName(service, network)
if err != nil {
return 0, os.NewSyscallError("GetServByName", err)
......@@ -111,6 +135,24 @@ func oldLookupPort(network, service string) (port int, err error) {
return int(syscall.Ntohs(s.Port)), nil
}
func oldLookupPort(network, service string) (port int, err error) {
// GetServByName return value is stored in thread local storage.
// Start new os thread before the call to prevent races.
type result struct {
port int
err error
}
ch := make(chan result)
go func() {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
port, err := getservbyname(network, service)
ch <- result{port: port, err: err}
}()
r := <-ch
return r.port, r.err
}
func newLookupPort(network, service string) (port int, err error) {
var stype int32
switch network {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment