Skip to content

Commit

Permalink
translate socket.timeout to TTransportException
Browse files Browse the repository at this point in the history
  • Loading branch information
aisk committed Jan 28, 2021
1 parent 108cca5 commit efede5f
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 9 deletions.
15 changes: 10 additions & 5 deletions thriftpy2/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,15 @@ def make_client(service, host="localhost", port=9090, unix_socket=None,
proto_factory=TBinaryProtocolFactory(),
trans_factory=TBufferedTransportFactory(),
timeout=3000, cafile=None, ssl_context=None, certfile=None,
keyfile=None, url="", socket_family=socket.AF_INET):
keyfile=None, url="", socket_family=socket.AF_INET,
handle_timeout_error=False):
if url:
parsed_url = urllib.parse.urlparse(url)
host = parsed_url.hostname or host
port = parsed_url.port or port
if unix_socket:
socket = TSocket(unix_socket=unix_socket, socket_timeout=timeout)
socket = TSocket(unix_socket=unix_socket, socket_timeout=timeout,
)
if certfile:
warnings.warn("SSL only works with host:port, not unix_socket.")
elif host and port:
Expand All @@ -47,7 +49,9 @@ def make_client(service, host="localhost", port=9090, unix_socket=None,
certfile=certfile, keyfile=keyfile,
ssl_context=ssl_context)
else:
socket = TSocket(host, port, socket_family=socket_family, socket_timeout=timeout)
socket = TSocket(host, port, socket_family=socket_family,
socket_timeout=timeout,
handle_timeout_error=handle_timeout_error)
else:
raise ValueError("Either host/port or unix_socket or url must be provided.")

Expand Down Expand Up @@ -91,7 +95,7 @@ def client_context(service, host="localhost", port=9090, unix_socket=None,
trans_factory=TBufferedTransportFactory(),
timeout=None, socket_timeout=3000, connect_timeout=3000,
cafile=None, ssl_context=None, certfile=None, keyfile=None,
url=""):
url="", handle_timeout_error=False):
if url:
parsed_url = urllib.parse.urlparse(url)
host = parsed_url.hostname or host
Expand Down Expand Up @@ -119,7 +123,8 @@ def client_context(service, host="localhost", port=9090, unix_socket=None,
else:
socket = TSocket(host, port,
connect_timeout=connect_timeout,
socket_timeout=socket_timeout)
socket_timeout=socket_timeout,
handle_timeout_error=handle_timeout_error)
else:
raise ValueError("Either host/port or unix_socket or url must be provided.")

Expand Down
24 changes: 22 additions & 2 deletions thriftpy2/transport/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ class TSocket(object):

def __init__(self, host=None, port=None, unix_socket=None,
sock=None, socket_family=socket.AF_INET,
socket_timeout=3000, connect_timeout=None):
socket_timeout=3000, connect_timeout=None,
handle_timeout_error=False):
"""Initialize a TSocket
TSocket can be initialized in 3 ways:
Expand All @@ -35,6 +36,8 @@ def __init__(self, host=None, port=None, unix_socket=None,
@param socket_timeout socket timeout in ms
@param connect_timeout connect timeout in ms, only used in
connection, will be set to socket_timeout if not set.
@param handle_timeout_error(bool) Whether translate socket.timeout
error to TTransportException. Default is False for compalibility.
"""
if sock:
self.sock = sock
Expand All @@ -54,6 +57,8 @@ def __init__(self, host=None, port=None, unix_socket=None,
self.connect_timeout = connect_timeout / 1000 if connect_timeout \
else self.socket_timeout

self.handle_timeout_error = handle_timeout_error

def _init_sock(self):
if self.unix_socket:
_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
Expand Down Expand Up @@ -108,6 +113,13 @@ def read(self, sz):
while True:
try:
buff = self.sock.recv(sz)
except socket.timeout:
if not self.handle_timeout_error:
raise
addr = self.sock.getsockname()
typ = TTransportException.TIMED_OUT
msg = "Timeouted when read from %s" % str(addr)
raise TTransportException(type=typ, message=msg)
except socket.error as e:
if e.errno == errno.EINTR:
continue
Expand All @@ -133,7 +145,15 @@ def read(self, sz):
return buff

def write(self, buff):
self.sock.sendall(buff)
try:
self.sock.sendall(buff)
except socket.timeout:
if not self.handle_timeout_error:
raise
addr = self.sock.getsockname()
typ = TTransportException.TIMED_OUT
msg = "Timeouted when write to %s" % str(addr)
raise TTransportException(type=typ, message=msg)

def flush(self):
pass
Expand Down
5 changes: 3 additions & 2 deletions thriftpy2/transport/sslsocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self, host, port, socket_family=socket.AF_INET,
socket_timeout=3000, connect_timeout=None,
ssl_context=None, validate=True,
cafile=None, capath=None, certfile=None, keyfile=None,
ciphers=DEFAULT_CIPHERS):
ciphers=DEFAULT_CIPHERS, handle_timeout_error=False):
"""Initialize a TSSLSocket
@param validate(bool) Set to False to disable SSL certificate
Expand All @@ -47,7 +47,8 @@ def __init__(self, host, port, socket_family=socket.AF_INET,
"""
super(TSSLSocket, self).__init__(
host=host, port=port, socket_family=socket_family,
connect_timeout=connect_timeout, socket_timeout=socket_timeout)
connect_timeout=connect_timeout, socket_timeout=socket_timeout,
handle_timeout_error=handle_timeout_error)

if ssl_context:
self.ssl_context = ssl_context
Expand Down

0 comments on commit efede5f

Please sign in to comment.