Commit 0e95842a authored by Denis Bilenko's avatar Denis Bilenko

a number of compatibility improvements in resolver_ares and ares.pyx:

resolver_ares.py:
- gethostbyname now handles '' (empty string)
- getaddrinfo now handles integer ports of type string, e.g "25". Thanks to kconor.
- getaddrinfo now converts UnicodeEncodeError into error('Int or String expected')
- getaddrinfo now uses the lowest 16 bits of passed port integer, to mimic socketmodule.c
- getnameinfo calls getaddrinfo to process arguments, similar to socketmodule.c
- gethostbyaddr also uses getaddrinfo to process arguments

ares.pyx:
- added InvalidIP exception
- gethostbyaddr and getnameinfo now raise InvalidIP immediatelly instead of passing it through callback
parent 51b2d8d5
...@@ -142,6 +142,10 @@ cpdef strerror(code): ...@@ -142,6 +142,10 @@ cpdef strerror(code):
return '%s: %s' % (_ares_errors.get(code) or code, cares.ares_strerror(code)) return '%s: %s' % (_ares_errors.get(code) or code, cares.ares_strerror(code))
class InvalidIP(ValueError):
pass
cdef void gevent_sock_state_callback(void *data, int s, int read, int write): cdef void gevent_sock_state_callback(void *data, int s, int read, int write):
if not data: if not data:
return return
...@@ -318,7 +322,7 @@ cdef public class channel [object PyGeventAresChannelObject, type PyGeventAresCh ...@@ -318,7 +322,7 @@ cdef public class channel [object PyGeventAresChannelObject, type PyGeventAresCh
elif cares.ares_inet_pton(AF_INET6, string, &c_servers[index].addr) > 0: elif cares.ares_inet_pton(AF_INET6, string, &c_servers[index].addr) > 0:
c_servers[index].family = AF_INET6 c_servers[index].family = AF_INET6
else: else:
raise ValueError('illegal IP address string: %r' % string) raise InvalidIP(repr(string))
c_servers[index].next = &c_servers[index] + 1 c_servers[index].next = &c_servers[index] + 1
index += 1 index += 1
if index >= length: if index >= length:
...@@ -398,9 +402,7 @@ cdef public class channel [object PyGeventAresChannelObject, type PyGeventAresCh ...@@ -398,9 +402,7 @@ cdef public class channel [object PyGeventAresChannelObject, type PyGeventAresCh
family = AF_INET6 family = AF_INET6
length = 16 length = 16
else: else:
# XXX raise immediatelly? raise InvalidIP(repr(addr))
callback(result(exception=ValueError('illegal IP address string: %r' % addr)))
return
cdef object arg = (self, callback) cdef object arg = (self, callback)
Py_INCREF(<PyObjectPtr>arg) Py_INCREF(<PyObjectPtr>arg)
cares.ares_gethostbyaddr(self.channel, addr_packed, length, family, <void*>gevent_ares_host_callback, <void*>arg) cares.ares_gethostbyaddr(self.channel, addr_packed, length, family, <void*>gevent_ares_host_callback, <void*>arg)
...@@ -421,9 +423,7 @@ cdef public class channel [object PyGeventAresChannelObject, type PyGeventAresCh ...@@ -421,9 +423,7 @@ cdef public class channel [object PyGeventAresChannelObject, type PyGeventAresCh
raise gaierror(-8, 'Invalid value for port: %r' % port) raise gaierror(-8, 'Invalid value for port: %r' % port)
cdef int length = gevent_make_sockaddr(hostp, port, flowinfo, scope_id, &sa6) cdef int length = gevent_make_sockaddr(hostp, port, flowinfo, scope_id, &sa6)
if length <= 0: if length <= 0:
# XXX raise immediatelly? like TypeError and gaierror raised above? raise InvalidIP(repr(hostp))
callback(result(exception=ValueError('illegal IP address string: %r' % hostp)))
return
cdef object arg = (self, callback) cdef object arg = (self, callback)
Py_INCREF(<PyObjectPtr>arg) Py_INCREF(<PyObjectPtr>arg)
cdef sockaddr_t* x = <sockaddr_t*>&sa6 cdef sockaddr_t* x = <sockaddr_t*>&sa6
......
...@@ -3,8 +3,8 @@ import os ...@@ -3,8 +3,8 @@ import os
import sys import sys
from _socket import getservbyname, getaddrinfo, gaierror, error from _socket import getservbyname, getaddrinfo, gaierror, error
from gevent.hub import Waiter, get_hub, basestring from gevent.hub import Waiter, get_hub, basestring
from gevent.socket import AF_UNSPEC, AF_INET, AF_INET6, SOCK_STREAM, SOCK_DGRAM, SOCK_RAW, AI_NUMERICHOST, EAI_SERVICE from gevent.socket import AF_UNSPEC, AF_INET, AF_INET6, SOCK_STREAM, SOCK_DGRAM, SOCK_RAW, AI_NUMERICHOST, EAI_SERVICE, AI_PASSIVE
from gevent.ares import channel from gevent.ares import channel, InvalidIP
__all__ = ['Resolver'] __all__ = ['Resolver']
...@@ -41,6 +41,7 @@ class Resolver(object): ...@@ -41,6 +41,7 @@ class Resolver(object):
self.fork_watcher.stop() self.fork_watcher.stop()
def gethostbyname(self, hostname, family=AF_INET): def gethostbyname(self, hostname, family=AF_INET):
hostname = _resolve_special(hostname, family)
return self.gethostbyname_ex(hostname, family)[-1][0] return self.gethostbyname_ex(hostname, family)[-1][0]
def gethostbyname_ex(self, hostname, family=AF_INET): def gethostbyname_ex(self, hostname, family=AF_INET):
...@@ -58,27 +59,37 @@ class Resolver(object): ...@@ -58,27 +59,37 @@ class Resolver(object):
def _lookup_port(self, port, socktype): def _lookup_port(self, port, socktype):
if isinstance(port, basestring): if isinstance(port, basestring):
try: try:
if socktype == 0: port = int(port)
try: except ValueError:
try:
if socktype == 0:
try:
port = getservbyname(port, 'tcp')
socktype = SOCK_STREAM
except error:
port = getservbyname(port, 'udp')
socktype = SOCK_DGRAM
elif socktype == SOCK_STREAM:
port = getservbyname(port, 'tcp') port = getservbyname(port, 'tcp')
socktype = SOCK_STREAM elif socktype == SOCK_DGRAM:
except error:
port = getservbyname(port, 'udp') port = getservbyname(port, 'udp')
socktype = SOCK_DGRAM else:
elif socktype == SOCK_STREAM: raise gaierror(EAI_SERVICE, 'Servname not supported for ai_socktype')
port = getservbyname(port, 'tcp') except error:
elif socktype == SOCK_DGRAM: ex = sys.exc_info()[1]
port = getservbyname(port, 'udp') if 'not found' in str(ex):
else: raise gaierror(EAI_SERVICE, 'Servname not supported for ai_socktype')
raise gaierror(EAI_SERVICE, 'Servname not supported for ai_socktype') else:
except error: raise gaierror(str(ex))
ex = sys.exc_info()[1] except UnicodeEncodeError:
if 'not found' in str(ex): raise error('Int or String expected')
raise gaierror(EAI_SERVICE, 'Servname not supported for ai_socktype')
else:
raise gaierror(str(ex))
elif port is None: elif port is None:
port = 0 port = 0
elif isinstance(port, int):
pass
else:
raise error('Int or String expected')
port = int(port % 65536)
return port, socktype return port, socktype
def _getaddrinfo(self, host, port, family=0, socktype=0, proto=0, flags=0): def _getaddrinfo(self, host, port, family=0, socktype=0, proto=0, flags=0):
...@@ -104,7 +115,6 @@ class Resolver(object): ...@@ -104,7 +115,6 @@ class Resolver(object):
if family == AF_UNSPEC: if family == AF_UNSPEC:
values = Values(self.hub, 2) values = Values(self.hub, 2)
# note, that we assume that ares.gethostbyname does not raise exceptions
ares.gethostbyname(values, host, AF_INET) ares.gethostbyname(values, host, AF_INET)
ares.gethostbyname(values, host, AF_INET6) ares.gethostbyname(values, host, AF_INET6)
elif family == AF_INET: elif family == AF_INET:
...@@ -159,15 +169,14 @@ class Resolver(object): ...@@ -159,15 +169,14 @@ class Resolver(object):
def _gethostbyaddr(self, ip_address): def _gethostbyaddr(self, ip_address):
waiter = Waiter(self.hub) waiter = Waiter(self.hub)
self.ares.gethostbyaddr(waiter, ip_address)
try: try:
self.ares.gethostbyaddr(waiter, ip_address)
return waiter.get() return waiter.get()
except ValueError: except InvalidIP:
ex = sys.exc_info()[1] result = self._getaddrinfo(ip_address, None, family=AF_UNSPEC, socktype=SOCK_DGRAM)
if not str(ex).startswith('illegal IP'): if not result:
raise raise
# socket.gethostbyaddr also accepts domain names; let's do that too _ip_address = result[0][-1][0]
_ip_address = self.gethostbyname(ip_address, 0)
if _ip_address == ip_address: if _ip_address == ip_address:
raise raise
waiter.clear() waiter.clear()
...@@ -175,6 +184,7 @@ class Resolver(object): ...@@ -175,6 +184,7 @@ class Resolver(object):
return waiter.get() return waiter.get()
def gethostbyaddr(self, ip_address): def gethostbyaddr(self, ip_address):
ip_address = _resolve_special(ip_address, AF_UNSPEC)
while True: while True:
ares = self.ares ares = self.ares
try: try:
...@@ -184,24 +194,27 @@ class Resolver(object): ...@@ -184,24 +194,27 @@ class Resolver(object):
raise raise
def _getnameinfo(self, sockaddr, flags): def _getnameinfo(self, sockaddr, flags):
if not isinstance(flags, int):
raise TypeError('an integer is required')
if not isinstance(sockaddr, tuple):
raise TypeError('getnameinfo() argument 1 must be a tuple')
waiter = Waiter(self.hub) waiter = Waiter(self.hub)
self.ares.getnameinfo(waiter, sockaddr, flags) result = self._getaddrinfo(sockaddr[0], str(sockaddr[1]), family=AF_UNSPEC, socktype=SOCK_DGRAM)
try: if not result:
result = waiter.get() raise
except ValueError: elif len(result) != 1:
ex = sys.exc_info()[1] raise error('sockaddr resolved to multiple addresses')
if not str(ex).startswith('illegal IP'): family, socktype, proto, name, address = result[0]
raise
# socket.getnameinfo also accepts domain names; let's do that too if family == AF_INET:
_ip_address = self.gethostbyname(sockaddr[0], 0) if len(sockaddr) != 2:
if _ip_address == sockaddr[0]: raise error("IPv4 sockaddr must be 2 tuple")
raise elif family == AF_INET6:
waiter.clear() address = address[:2] + sockaddr[2:]
self.ares.getnameinfo(waiter, (_ip_address, ) + sockaddr[1:], flags)
result = waiter.get() self.ares.getnameinfo(waiter, address, flags)
if result[1] is None: return waiter.get()
return (result[0], str(sockaddr[1])) + result[2:]
return result
def getnameinfo(self, sockaddr, flags): def getnameinfo(self, sockaddr, flags):
while True: while True:
...@@ -240,3 +253,12 @@ class Values(object): ...@@ -240,3 +253,12 @@ class Values(object):
return self.values return self.values
else: else:
raise self.error raise self.error
def _resolve_special(hostname, family):
if hostname == '':
result = getaddrinfo(None, 0, family, SOCK_DGRAM, 0, AI_PASSIVE)
if len(result) != 1:
raise error('wildcard resolved to multiple address')
return result[0][4][0]
return hostname
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment