From efede5f1d842fa3223b8a751308bf03ceb6125ce Mon Sep 17 00:00:00 2001 From: An Long Date: Thu, 28 Jan 2021 16:06:14 +0800 Subject: [PATCH] translate socket.timeout to TTransportException --- thriftpy2/rpc.py | 15 ++++++++++----- thriftpy2/transport/socket.py | 24 ++++++++++++++++++++++-- thriftpy2/transport/sslsocket.py | 5 +++-- 3 files changed, 35 insertions(+), 9 deletions(-) diff --git a/thriftpy2/rpc.py b/thriftpy2/rpc.py index 44d33ab1..2a2fac5c 100644 --- a/thriftpy2/rpc.py +++ b/thriftpy2/rpc.py @@ -31,13 +31,15 @@ def make_client(service, host="localhost", port=9090, unix_socket=None, proto_factory=TBinaryProtocolFactory(), trans_factory=TBufferedTransportFactory(), timeout=3000, cafile=None, ssl_context=None, certfile=None, - keyfile=None, url="", socket_family=socket.AF_INET): + keyfile=None, url="", socket_family=socket.AF_INET, + handle_timeout_error=False): if url: parsed_url = urllib.parse.urlparse(url) host = parsed_url.hostname or host port = parsed_url.port or port if unix_socket: - socket = TSocket(unix_socket=unix_socket, socket_timeout=timeout) + socket = TSocket(unix_socket=unix_socket, socket_timeout=timeout, + ) if certfile: warnings.warn("SSL only works with host:port, not unix_socket.") elif host and port: @@ -47,7 +49,9 @@ def make_client(service, host="localhost", port=9090, unix_socket=None, certfile=certfile, keyfile=keyfile, ssl_context=ssl_context) else: - socket = TSocket(host, port, socket_family=socket_family, socket_timeout=timeout) + socket = TSocket(host, port, socket_family=socket_family, + socket_timeout=timeout, + handle_timeout_error=handle_timeout_error) else: raise ValueError("Either host/port or unix_socket or url must be provided.") @@ -91,7 +95,7 @@ def client_context(service, host="localhost", port=9090, unix_socket=None, trans_factory=TBufferedTransportFactory(), timeout=None, socket_timeout=3000, connect_timeout=3000, cafile=None, ssl_context=None, certfile=None, keyfile=None, - url=""): + url="", handle_timeout_error=False): if url: parsed_url = urllib.parse.urlparse(url) host = parsed_url.hostname or host @@ -119,7 +123,8 @@ def client_context(service, host="localhost", port=9090, unix_socket=None, else: socket = TSocket(host, port, connect_timeout=connect_timeout, - socket_timeout=socket_timeout) + socket_timeout=socket_timeout, + handle_timeout_error=handle_timeout_error) else: raise ValueError("Either host/port or unix_socket or url must be provided.") diff --git a/thriftpy2/transport/socket.py b/thriftpy2/transport/socket.py index 0db0e3ad..dcf2ef4f 100644 --- a/thriftpy2/transport/socket.py +++ b/thriftpy2/transport/socket.py @@ -16,7 +16,8 @@ class TSocket(object): def __init__(self, host=None, port=None, unix_socket=None, sock=None, socket_family=socket.AF_INET, - socket_timeout=3000, connect_timeout=None): + socket_timeout=3000, connect_timeout=None, + handle_timeout_error=False): """Initialize a TSocket TSocket can be initialized in 3 ways: @@ -35,6 +36,8 @@ def __init__(self, host=None, port=None, unix_socket=None, @param socket_timeout socket timeout in ms @param connect_timeout connect timeout in ms, only used in connection, will be set to socket_timeout if not set. + @param handle_timeout_error(bool) Whether translate socket.timeout + error to TTransportException. Default is False for compalibility. """ if sock: self.sock = sock @@ -54,6 +57,8 @@ def __init__(self, host=None, port=None, unix_socket=None, self.connect_timeout = connect_timeout / 1000 if connect_timeout \ else self.socket_timeout + self.handle_timeout_error = handle_timeout_error + def _init_sock(self): if self.unix_socket: _sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) @@ -108,6 +113,13 @@ def read(self, sz): while True: try: buff = self.sock.recv(sz) + except socket.timeout: + if not self.handle_timeout_error: + raise + addr = self.sock.getsockname() + typ = TTransportException.TIMED_OUT + msg = "Timeouted when read from %s" % str(addr) + raise TTransportException(type=typ, message=msg) except socket.error as e: if e.errno == errno.EINTR: continue @@ -133,7 +145,15 @@ def read(self, sz): return buff def write(self, buff): - self.sock.sendall(buff) + try: + self.sock.sendall(buff) + except socket.timeout: + if not self.handle_timeout_error: + raise + addr = self.sock.getsockname() + typ = TTransportException.TIMED_OUT + msg = "Timeouted when write to %s" % str(addr) + raise TTransportException(type=typ, message=msg) def flush(self): pass diff --git a/thriftpy2/transport/sslsocket.py b/thriftpy2/transport/sslsocket.py index 3972e348..5d569a4d 100644 --- a/thriftpy2/transport/sslsocket.py +++ b/thriftpy2/transport/sslsocket.py @@ -23,7 +23,7 @@ def __init__(self, host, port, socket_family=socket.AF_INET, socket_timeout=3000, connect_timeout=None, ssl_context=None, validate=True, cafile=None, capath=None, certfile=None, keyfile=None, - ciphers=DEFAULT_CIPHERS): + ciphers=DEFAULT_CIPHERS, handle_timeout_error=False): """Initialize a TSSLSocket @param validate(bool) Set to False to disable SSL certificate @@ -47,7 +47,8 @@ def __init__(self, host, port, socket_family=socket.AF_INET, """ super(TSSLSocket, self).__init__( host=host, port=port, socket_family=socket_family, - connect_timeout=connect_timeout, socket_timeout=socket_timeout) + connect_timeout=connect_timeout, socket_timeout=socket_timeout, + handle_timeout_error=handle_timeout_error) if ssl_context: self.ssl_context = ssl_context