Commit e0106942 authored by Denis Bilenko's avatar Denis Bilenko

socket cleanup: remove __getattr__, rename fd to _sock, fix sendall

- store the real socket as '_sock' instead of 'fd'; fd is available as a deprecated alias
- __getattr__ is gone; it screws up subclassing; the delegated methods are implemented directly, similar to how stdlib socket does it
- sendall is fixed to never call _sock.sendall (it used to do it when timeout==0.0) - this won't work for ssl subclasses
  it also does not call time.time() twice in a row anymore
parent 2172dbed
......@@ -78,6 +78,7 @@ else:
import _socket
error = _socket.error
timeout = _socket.timeout
_realsocket = _socket.socket
__socket__ = __import__('socket')
_fileobject = __socket__._fileobject
try:
......@@ -222,18 +223,18 @@ class socket(object):
def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, _sock=None):
if _sock is None:
self.fd = _socket.socket(family, type, proto)
self._sock = _realsocket(family, type, proto)
self.timeout = _socket.getdefaulttimeout()
else:
if hasattr(_sock, '_sock'):
self.fd = _sock._sock
self.timeout = getattr(_sock, 'timeout', None)
if self.timeout is None:
self._sock = _sock._sock
self.timeout = getattr(_sock, 'timeout', False)
if self.timeout is False:
self.timeout = _socket.getdefaulttimeout()
else:
self.fd = _sock
self._sock = _sock
self.timeout = _socket.getdefaulttimeout()
self.fd.setblocking(0)
self._sock.setblocking(0)
def __repr__(self):
return '<%s at %s %s>' % (type(self).__name__, hex(id(self)), self._formatinfo())
......@@ -266,18 +267,17 @@ class socket(object):
return result
@property
def _sock(self):
return self.fd
def __getattr__(self, item):
return getattr(self.fd, item)
def fd(self):
import warnings
warnings.warn("socket.fd is deprecated; use socket._sock", DeprecationWarning, stacklevel=2)
return self._sock
def accept(self):
if self.timeout == 0.0:
return self.fd.accept()
return self._sock.accept()
while True:
try:
res = self.fd.accept()
res = self._sock.accept()
except error, ex:
if ex[0] == errno.EWOULDBLOCK:
res = None
......@@ -285,12 +285,12 @@ class socket(object):
raise
if res is not None:
client, addr = res
return type(self)(_sock=client), addr
wait_read(self.fd.fileno(), timeout=self.timeout)
return socket(_sock=client), addr
wait_read(self._sock.fileno(), timeout=self.timeout)
def close(self):
self.fd = _closedsocket()
dummy = self.fd._dummy
self._sock = _closedsocket()
dummy = self._sock._dummy
for method in _delegate_methods:
setattr(self, method, dummy)
......@@ -298,8 +298,8 @@ class socket(object):
if isinstance(address, tuple) and len(address)==2:
address = gethostbyname(address[0]), address[1]
if self.timeout == 0.0:
return self.fd.connect(address)
sock = self.fd
return self._sock.connect(address)
sock = self._sock
if self.timeout is None:
while True:
err = sock.getsockopt(SOL_SOCKET, SO_ERROR)
......@@ -352,7 +352,7 @@ class socket(object):
def recv(self, *args):
while True:
try:
return self.fd.recv(*args)
return self._sock.recv(*args)
except error, ex:
if ex[0] != EWOULDBLOCK or self.timeout == 0.0:
raise
......@@ -363,77 +363,79 @@ class socket(object):
def recvfrom(self, *args):
while True:
try:
return self.fd.recvfrom(*args)
return self._sock.recvfrom(*args)
except error, ex:
sys.exc_clear()
if ex[0] != EWOULDBLOCK or self.timeout == 0.0:
raise ex
wait_read(self.fd.fileno(), timeout=self.timeout)
wait_read(self._sock.fileno(), timeout=self.timeout)
def recvfrom_into(self, *args):
while True:
try:
return self.fd.recvfrom_into(*args)
return self._sock.recvfrom_into(*args)
except error, ex:
if ex[0] != EWOULDBLOCK or self.timeout == 0.0:
raise
sys.exc_clear()
wait_read(self.fd.fileno(), timeout=self.timeout)
wait_read(self._sock.fileno(), timeout=self.timeout)
def recv_into(self, *args):
while True:
try:
return self.fd.recv_into(*args)
return self._sock.recv_into(*args)
except error, ex:
if ex[0] != EWOULDBLOCK or self.timeout == 0.0:
raise
sys.exc_clear()
wait_read(self.fd.fileno(), timeout=self.timeout)
wait_read(self._sock.fileno(), timeout=self.timeout)
def send(self, data, flags=0, timeout=timeout_default):
if timeout is timeout_default:
timeout = self.timeout
try:
return self.fd.send(data, flags)
return self._sock.send(data, flags)
except error, ex:
if ex[0] != EWOULDBLOCK or timeout == 0.0:
raise
sys.exc_clear()
wait_write(self.fd.fileno(), timeout=timeout)
wait_write(self._sock.fileno(), timeout=timeout)
try:
return self.fd.send(data, flags)
return self._sock.send(data, flags)
except error, ex2:
if ex2[0] == EWOULDBLOCK:
return 0
raise
def sendall(self, data, flags=0):
# this sendall is also reused by GreenSSL, so it must not call self.fd methods directly
# this sendall is also reused by SSL subclasses (both from ssl and sslold modules),
# so it should not call self._sock methods directly
if self.timeout is None:
data_sent = 0
while data_sent < len(data):
data_sent += self.send(data[data_sent:], flags)
elif not self.timeout:
return self.fd.sendall(data)
else:
end = time.time() + self.timeout
timeleft = self.timeout
end = time.time() + timeleft
data_sent = 0
while data_sent < len(data):
while True:
data_sent += self.send(data[data_sent:], flags, timeout=timeleft)
if data_sent >= len(data):
break
timeleft = end - time.time()
if timeleft <= 0:
raise timeout
data_sent += self.send(data[data_sent:], flags, timeout=timeleft)
def sendto(self, *args):
try:
return self.fd.sendto(*args)
return self._sock.sendto(*args)
except error, ex:
if ex[0] != EWOULDBLOCK or timeout == 0.0:
raise
sys.exc_clear()
wait_write(self.fileno(), timeout=self.timeout)
try:
return self.fd.sendto(*args)
return self._sock.sendto(*args)
except error, ex2:
if ex2[0] == EWOULDBLOCK:
return 0
......@@ -459,6 +461,18 @@ class socket(object):
def gettimeout(self):
return self.timeout
family = property(lambda self: self._sock.family, doc="the socket family")
type = property(lambda self: self._sock.type, doc="the socket type")
proto = property(lambda self: self._sock.proto, doc="the socket protocol")
# delegate the functions that we haven't implemented to the real socket object
_s = ("def %s(self, *args): return self._sock.%s(*args)\n\n"
"%s.__doc__ = _realsocket.%s.__doc__\n")
for _m in set(__socket__._socketmethods) - set(locals()):
exec _s % (_m, _m, _m, _m)
del _m, _s
GreenSocket = socket # XXX this alias will be removed
SysCallError_code_mapping = {-1: 8}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment