From ba24c43a4b7688d98193fae0eb88e0e17654a169 Mon Sep 17 00:00:00 2001 From: Walter BONETTI Date: Mon, 29 Jul 2024 17:23:52 -0400 Subject: [PATCH] Fix: missing error handling on PyOpenSSL error Signed-off-by: Walter BONETTI --- src/paho/mqtt/client.py | 1215 +++++++++++++++++++++++++-------------- 1 file changed, 778 insertions(+), 437 deletions(-) diff --git a/src/paho/mqtt/client.py b/src/paho/mqtt/client.py index b4c69807..6756e5f4 100644 --- a/src/paho/mqtt/client.py +++ b/src/paho/mqtt/client.py @@ -35,11 +35,33 @@ import urllib.request import uuid import warnings -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, NamedTuple, Sequence, Tuple, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Sequence, + Tuple, + Union, + cast, +) from paho.mqtt.packettypes import PacketTypes -from .enums import CallbackAPIVersion, ConnackCode, LogLevel, MessageState, MessageType, MQTTErrorCode, MQTTProtocolVersion, PahoClientMode, _ConnectionState +from .enums import ( + CallbackAPIVersion, + ConnackCode, + LogLevel, + MessageState, + MessageType, + MQTTErrorCode, + MQTTProtocolVersion, + PahoClientMode, + _ConnectionState, +) from .matcher import MQTTMatcher from .properties import Properties from .reasoncodes import ReasonCode, ReasonCodes @@ -54,10 +76,10 @@ def _subject_alt_name_string(cert: X509) -> list: san = [] for i in range(cert.get_extension_count()): ext = cert.get_extension(i) - if ext.get_short_name() == b'subjectAltName': - san_entries = ext.__str__().split(', ') + if ext.get_short_name() == b"subjectAltName": + san_entries = ext.__str__().split(", ") for entry in san_entries: - key, value = entry.split(':', 1) + key, value = entry.split(":", 1) san.append((key.strip(), value.strip())) return san @@ -66,27 +88,37 @@ def _openssl_match_hostname(cert: X509, hostname: str): CertificateError is raised on failure. On success, the function returns nothing. """ if not cert: - raise ValueError("Empty or no certificate. match_hostname needs a certificate.") + raise ValueError( + "Empty or no certificate. match_hostname needs a certificate." + ) dnsnames = [] # Extract subject alternative name (SAN) entries san = _subject_alt_name_string(cert) for key, value in san: - if key == 'DNS': + if key == "DNS": if ssl._dnsname_match(value, hostname): return dnsnames.append(value) if not dnsnames: # TODO: check if no dns entry to use subject - raise ValueError("pyOpenssl match_hostname: using subject is not supported.") + raise ValueError( + "pyOpenssl match_hostname: using subject is not supported." + ) if len(dnsnames) > 1: - raise ssl.CertificateError(f"Hostname {hostname} doesn't match any of {', '.join(map(repr, dnsnames))}") + raise ssl.CertificateError( + f"Hostname {hostname} doesn't match any of {', '.join(map(repr, dnsnames))}" + ) elif len(dnsnames) == 1: - raise ssl.CertificateError(f"Hostname {hostname} doesn't match {dnsnames[0]}") + raise ssl.CertificateError( + f"Hostname {hostname} doesn't match {dnsnames[0]}" + ) else: - raise ssl.CertificateError("No appropriate commonName or subjectAltName fields were found") + raise ssl.CertificateError( + "No appropriate commonName or subjectAltName fields were found" + ) HAS_OPENSSL = True except ImportError: @@ -118,7 +150,6 @@ class _InPacket(TypedDict): to_process: int pos: int - class _OutPacket(TypedDict): command: int mid: int @@ -129,16 +160,11 @@ class _OutPacket(TypedDict): info: MQTTMessageInfo | None class SocketLike(Protocol): - def recv(self, buffer_size: int) -> bytes: - ... - def send(self, buffer: bytes) -> int: - ... - def close(self) -> None: - ... - def fileno(self) -> int: - ... - def setblocking(self, flag: bool) -> None: - ... + def recv(self, buffer_size: int) -> bytes: ... + def send(self, buffer: bytes) -> int: ... + def close(self) -> None: ... + def fileno(self) -> int: ... + def setblocking(self, flag: bool) -> None: ... try: @@ -167,7 +193,7 @@ def setblocking(self, flag: bool) -> None: HAVE_DNS = False -if platform.system() == 'Windows': +if platform.system() == "Windows": EAGAIN = errno.WSAEWOULDBLOCK # type: ignore[attr-defined] else: EAGAIN = errno.EAGAIN @@ -211,7 +237,9 @@ def setblocking(self, flag: bool) -> None: CONNACK_REFUSED_PROTOCOL_VERSION = ConnackCode.CONNACK_REFUSED_PROTOCOL_VERSION CONNACK_REFUSED_IDENTIFIER_REJECTED = ConnackCode.CONNACK_REFUSED_IDENTIFIER_REJECTED CONNACK_REFUSED_SERVER_UNAVAILABLE = ConnackCode.CONNACK_REFUSED_SERVER_UNAVAILABLE -CONNACK_REFUSED_BAD_USERNAME_PASSWORD = ConnackCode.CONNACK_REFUSED_BAD_USERNAME_PASSWORD +CONNACK_REFUSED_BAD_USERNAME_PASSWORD = ( + ConnackCode.CONNACK_REFUSED_BAD_USERNAME_PASSWORD +) CONNACK_REFUSED_NOT_AUTHORIZED = ConnackCode.CONNACK_REFUSED_NOT_AUTHORIZED # Message state @@ -294,16 +322,28 @@ class DisconnectFlags(NamedTuple): """ -CallbackOnConnect_v1_mqtt3 = Callable[["Client", Any, Dict[str, Any], MQTTErrorCode], None] -CallbackOnConnect_v1_mqtt5 = Callable[["Client", Any, Dict[str, Any], ReasonCode, Union[Properties, None]], None] +CallbackOnConnect_v1_mqtt3 = Callable[ + ["Client", Any, Dict[str, Any], MQTTErrorCode], None +] +CallbackOnConnect_v1_mqtt5 = Callable[ + ["Client", Any, Dict[str, Any], ReasonCode, Union[Properties, None]], None +] CallbackOnConnect_v1 = Union[CallbackOnConnect_v1_mqtt5, CallbackOnConnect_v1_mqtt3] -CallbackOnConnect_v2 = Callable[["Client", Any, ConnectFlags, ReasonCode, Union[Properties, None]], None] +CallbackOnConnect_v2 = Callable[ + ["Client", Any, ConnectFlags, ReasonCode, Union[Properties, None]], None +] CallbackOnConnect = Union[CallbackOnConnect_v1, CallbackOnConnect_v2] CallbackOnConnectFail = Callable[["Client", Any], None] CallbackOnDisconnect_v1_mqtt3 = Callable[["Client", Any, MQTTErrorCode], None] -CallbackOnDisconnect_v1_mqtt5 = Callable[["Client", Any, Union[ReasonCode, int, None], Union[Properties, None]], None] -CallbackOnDisconnect_v1 = Union[CallbackOnDisconnect_v1_mqtt3, CallbackOnDisconnect_v1_mqtt5] -CallbackOnDisconnect_v2 = Callable[["Client", Any, DisconnectFlags, ReasonCode, Union[Properties, None]], None] +CallbackOnDisconnect_v1_mqtt5 = Callable[ + ["Client", Any, Union[ReasonCode, int, None], Union[Properties, None]], None +] +CallbackOnDisconnect_v1 = Union[ + CallbackOnDisconnect_v1_mqtt3, CallbackOnDisconnect_v1_mqtt5 +] +CallbackOnDisconnect_v2 = Callable[ + ["Client", Any, DisconnectFlags, ReasonCode, Union[Properties, None]], None +] CallbackOnDisconnect = Union[CallbackOnDisconnect_v1, CallbackOnDisconnect_v2] CallbackOnLog = Callable[["Client", Any, int, str], None] CallbackOnMessage = Callable[["Client", Any, "MQTTMessage"], None] @@ -313,14 +353,26 @@ class DisconnectFlags(NamedTuple): CallbackOnPublish = Union[CallbackOnPublish_v1, CallbackOnPublish_v2] CallbackOnSocket = Callable[["Client", Any, "SocketLike"], None] CallbackOnSubscribe_v1_mqtt3 = Callable[["Client", Any, int, Tuple[int, ...]], None] -CallbackOnSubscribe_v1_mqtt5 = Callable[["Client", Any, int, List[ReasonCode], Properties], None] -CallbackOnSubscribe_v1 = Union[CallbackOnSubscribe_v1_mqtt3, CallbackOnSubscribe_v1_mqtt5] -CallbackOnSubscribe_v2 = Callable[["Client", Any, int, List[ReasonCode], Union[Properties, None]], None] +CallbackOnSubscribe_v1_mqtt5 = Callable[ + ["Client", Any, int, List[ReasonCode], Properties], None +] +CallbackOnSubscribe_v1 = Union[ + CallbackOnSubscribe_v1_mqtt3, CallbackOnSubscribe_v1_mqtt5 +] +CallbackOnSubscribe_v2 = Callable[ + ["Client", Any, int, List[ReasonCode], Union[Properties, None]], None +] CallbackOnSubscribe = Union[CallbackOnSubscribe_v1, CallbackOnSubscribe_v2] CallbackOnUnsubscribe_v1_mqtt3 = Callable[["Client", Any, int], None] -CallbackOnUnsubscribe_v1_mqtt5 = Callable[["Client", Any, int, Properties, Union[ReasonCode, List[ReasonCode]]], None] -CallbackOnUnsubscribe_v1 = Union[CallbackOnUnsubscribe_v1_mqtt3, CallbackOnUnsubscribe_v1_mqtt5] -CallbackOnUnsubscribe_v2 = Callable[["Client", Any, int, List[ReasonCode], Union[Properties, None]], None] +CallbackOnUnsubscribe_v1_mqtt5 = Callable[ + ["Client", Any, int, Properties, Union[ReasonCode, List[ReasonCode]]], None +] +CallbackOnUnsubscribe_v1 = Union[ + CallbackOnUnsubscribe_v1_mqtt3, CallbackOnUnsubscribe_v1_mqtt5 +] +CallbackOnUnsubscribe_v2 = Callable[ + ["Client", Any, int, List[ReasonCode], Union[Properties, None]], None +] CallbackOnUnsubscribe = Union[CallbackOnUnsubscribe_v1, CallbackOnUnsubscribe_v2] # This is needed for typing because class Client redefined the name "socket" @@ -328,10 +380,11 @@ class DisconnectFlags(NamedTuple): class WebsocketConnectionError(ConnectionError): - """ WebsocketConnectionError is a subclass of ConnectionError. + """WebsocketConnectionError is a subclass of ConnectionError. - It's raised when unable to perform the Websocket handshake. + It's raised when unable to perform the Websocket handshake. """ + pass @@ -375,7 +428,7 @@ def error_string(mqtt_errno: MQTTErrorCode | int) -> str: return "Unknown error." -def connack_string(connack_code: int|ReasonCode) -> str: +def connack_string(connack_code: int | ReasonCode) -> str: """Return the string associated with a CONNACK result or CONNACK reason code.""" if isinstance(connack_code, ReasonCode): return str(connack_code) @@ -464,7 +517,7 @@ def _base62( num, rest = divmod(num, 62) digits.append(base[rest]) digits.extend(base[0] for _ in range(len(digits), padding)) - return ''.join(reversed(digits)) + return "".join(reversed(digits)) def topic_matches_sub(sub: str, topic: str) -> bool: @@ -486,15 +539,13 @@ def topic_matches_sub(sub: str, topic: str) -> bool: def _socketpair_compat() -> tuple[socket.socket, socket.socket]: """TCP/IP socketpair including Windows support""" - listensock = socket.socket( - socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_IP) + listensock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_IP) listensock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) listensock.bind(("127.0.0.1", 0)) listensock.listen(1) iface, port = listensock.getsockname() - sock1 = socket.socket( - socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_IP) + sock1 = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_IP) sock1.setblocking(False) try: sock1.connect(("127.0.0.1", port)) @@ -512,7 +563,9 @@ def _force_bytes(s: str | bytes) -> bytes: return s -def _encode_payload(payload: str | bytes | bytearray | int | float | None) -> bytes|bytearray: +def _encode_payload( + payload: str | bytes | bytearray | int | float | None, +) -> bytes | bytearray: if isinstance(payload, str): return payload.encode("utf-8") @@ -523,9 +576,7 @@ def _encode_payload(payload: str | bytes | bytearray | int | float | None) -> by return b"" if not isinstance(payload, (bytes, bytearray)): - raise TypeError( - "payload must be a string, bytearray, int, float or None." - ) + raise TypeError("payload must be a string, bytearray, int, float or None.") return payload @@ -536,7 +587,7 @@ class MQTTMessageInfo: message has been published, and/or wait until it is published. """ - __slots__ = 'mid', '_published', '_condition', 'rc', '_iterpos' + __slots__ = "mid", "_published", "_condition", "rc", "_iterpos" def __init__(self, mid: int): self.mid = mid @@ -594,14 +645,15 @@ def wait_for_publish(self, timeout: float | None = None) -> None: reason. """ if self.rc == MQTT_ERR_QUEUE_SIZE: - raise ValueError('Message is not queued due to ERR_QUEUE_SIZE') + raise ValueError("Message is not queued due to ERR_QUEUE_SIZE") elif self.rc == MQTT_ERR_AGAIN: pass elif self.rc > 0: - raise RuntimeError(f'Message publish failed: {error_string(self.rc)}') + raise RuntimeError(f"Message publish failed: {error_string(self.rc)}") timeout_time = None if timeout is None else time_func() + timeout - timeout_tenth = None if timeout is None else timeout / 10. + timeout_tenth = None if timeout is None else timeout / 10.0 + def timed_out() -> bool: return False if timeout_time is None else time_func() > timeout_time @@ -610,7 +662,7 @@ def timed_out() -> bool: self._condition.wait(timeout_tenth) if self.rc > 0: - raise RuntimeError(f'Message publish failed: {error_string(self.rc)}') + raise RuntimeError(f"Message publish failed: {error_string(self.rc)}") def is_published(self) -> bool: """Returns True if the message associated with this object has been @@ -619,21 +671,33 @@ def is_published(self) -> bool: To wait for this to become true, look at `wait_for_publish`. """ if self.rc == MQTTErrorCode.MQTT_ERR_QUEUE_SIZE: - raise ValueError('Message is not queued due to ERR_QUEUE_SIZE') + raise ValueError("Message is not queued due to ERR_QUEUE_SIZE") elif self.rc == MQTTErrorCode.MQTT_ERR_AGAIN: pass elif self.rc > 0: - raise RuntimeError(f'Message publish failed: {error_string(self.rc)}') + raise RuntimeError(f"Message publish failed: {error_string(self.rc)}") with self._condition: return self._published class MQTTMessage: - """ This is a class that describes an incoming message. It is + """This is a class that describes an incoming message. It is passed to the `on_message` callback as the message parameter. """ - __slots__ = 'timestamp', 'state', 'dup', 'mid', '_topic', 'payload', 'qos', 'retain', 'info', 'properties' + + __slots__ = ( + "timestamp", + "state", + "dup", + "mid", + "_topic", + "payload", + "qos", + "retain", + "info", + "properties", + ) def __init__(self, mid: int = 0, topic: bytes = b""): self.timestamp = 0.0 @@ -668,7 +732,7 @@ def topic(self) -> str: This property is read-only. """ - return self._topic.decode('utf-8') + return self._topic.decode("utf-8") @topic.setter def topic(self, value: bytes) -> None: @@ -793,7 +857,8 @@ def __init__( raise ValueError('"unix" transport not supported') elif transport not in ("websockets", "tcp", "unix"): raise ValueError( - f'transport must be "websockets", "tcp" or "unix", not {transport}') + f'transport must be "websockets", "tcp" or "unix", not {transport}' + ) self._manual_ack = manual_ack self._transport = transport @@ -826,13 +891,14 @@ def __init__( if protocol == MQTTv5: if clean_session is not None: - raise ValueError('Clean session is not used for MQTT 5.0') + raise ValueError("Clean session is not used for MQTT 5.0") else: if clean_session is None: clean_session = True if not clean_session and (client_id == "" or client_id is None): raise ValueError( - 'A client id must be provided if clean session is False.') + "A client id must be provided if clean session is False." + ) self._clean_session = clean_session # [MQTT-3.1.3-4] Client Id must be UTF-8 encoded string. @@ -866,12 +932,12 @@ def __init__( self._ping_t = 0.0 self._last_mid = 0 self._state = _ConnectionState.MQTT_CS_NEW - self._out_messages: collections.OrderedDict[ - int, MQTTMessage - ] = collections.OrderedDict() - self._in_messages: collections.OrderedDict[ - int, MQTTMessage - ] = collections.OrderedDict() + self._out_messages: collections.OrderedDict[int, MQTTMessage] = ( + collections.OrderedDict() + ) + self._in_messages: collections.OrderedDict[int, MQTTMessage] = ( + collections.OrderedDict() + ) self._max_inflight_messages = 20 self._inflight_messages = 0 self._max_queued_messages = 0 @@ -921,7 +987,7 @@ def __init__( self._websocket_extra_headers: WebSocketHeaders | None = None # for clean_start == MQTT_CLEAN_START_FIRST_ONLY self._mqttv5_first_connect = True - self.suppress_exceptions = False # For callbacks + self.suppress_exceptions = False # For callbacks def __del__(self) -> None: self._reset_sockets() @@ -938,7 +1004,9 @@ def host(self) -> str: @host.setter def host(self, value: str) -> None: if not self._connection_closed(): - raise RuntimeError("updating host on established connection is not supported") + raise RuntimeError( + "updating host on established connection is not supported" + ) if not value: raise ValueError("Invalid host.") @@ -956,7 +1024,9 @@ def port(self) -> int: @port.setter def port(self, value: int) -> None: if not self._connection_closed(): - raise RuntimeError("updating port on established connection is not supported") + raise RuntimeError( + "updating port on established connection is not supported" + ) if value <= 0: raise ValueError("Invalid port number.") @@ -976,7 +1046,9 @@ def keepalive(self, value: int) -> None: if not self._connection_closed(): # The issue here is that the previous value of keepalive matter to possibly # sent ping packet. - raise RuntimeError("updating keepalive on established connection is not supported") + raise RuntimeError( + "updating keepalive on established connection is not supported" + ) if value < 0: raise ValueError("Keepalive must be >=0.") @@ -995,7 +1067,9 @@ def transport(self) -> Literal["tcp", "websockets", "unix"]: @transport.setter def transport(self, value: Literal["tcp", "websockets"]) -> None: if not self._connection_closed(): - raise RuntimeError("updating transport on established connection is not supported") + raise RuntimeError( + "updating transport on established connection is not supported" + ) self._transport = value @@ -1020,7 +1094,9 @@ def connect_timeout(self) -> float: @connect_timeout.setter def connect_timeout(self, value: float) -> None: if not self._connection_closed(): - raise RuntimeError("updating connect_timeout on established connection is not supported") + raise RuntimeError( + "updating connect_timeout on established connection is not supported" + ) if value <= 0.0: raise ValueError("timeout must be a positive number") @@ -1040,7 +1116,9 @@ def username(self) -> str | None: @username.setter def username(self, value: str | None) -> None: if not self._connection_closed(): - raise RuntimeError("updating username on established connection is not supported") + raise RuntimeError( + "updating username on established connection is not supported" + ) if value is None: self._username = None @@ -1060,7 +1138,9 @@ def password(self) -> str | None: @password.setter def password(self, value: str | None) -> None: if not self._connection_closed(): - raise RuntimeError("updating password on established connection is not supported") + raise RuntimeError( + "updating password on established connection is not supported" + ) if value is None: self._password = None @@ -1081,7 +1161,9 @@ def max_inflight_messages(self, value: int) -> None: if not self._connection_closed(): # Not tested. Some doubt that everything is okay when max_inflight change between 0 # and > 0 value because _update_inflight is skipped when _max_inflight_messages == 0 - raise RuntimeError("updating max_inflight_messages on established connection is not supported") + raise RuntimeError( + "updating max_inflight_messages on established connection is not supported" + ) if value < 0: raise ValueError("Invalid inflight.") @@ -1101,7 +1183,9 @@ def max_queued_messages(self) -> int: def max_queued_messages(self, value: int) -> None: if not self._connection_closed(): # Not tested. - raise RuntimeError("updating max_queued_messages on established connection is not supported") + raise RuntimeError( + "updating max_queued_messages on established connection is not supported" + ) if value < 0: raise ValueError("Invalid queue size.") @@ -1140,32 +1224,64 @@ def logger(self, value: logging.Logger | None) -> None: def _sock_recv(self, bufsize: int) -> bytes: if self._sock is None: raise ConnectionError("self._sock is None") - try: - return self._sock.recv(bufsize) - except ssl.SSLWantReadError as err: - raise BlockingIOError() from err - except ssl.SSLWantWriteError as err: - self._call_socket_register_write() - raise BlockingIOError() from err - except AttributeError as err: - self._easy_log( - MQTT_LOG_DEBUG, "socket was None: %s", err) - raise ConnectionError() from err + + if HAS_OPENSSL and isinstance(self._ssl_context, SSL.Context): + from OpenSSL import SSL + + try: + return self._sock.recv(bufsize) + except SSL.WantReadError as err: + raise BlockingIOError() from err + except SSL.WantWriteError as err: + self._call_socket_register_write() + raise BlockingIOError() from err + except SSL.ZeroReturnError as err: + raise BlockingIOError() from err + except AttributeError as err: + self._easy_log(MQTT_LOG_DEBUG, "socket was None: %s", err) + raise ConnectionError() from err + else: + try: + return self._sock.recv(bufsize) + except ssl.SSLWantReadError as err: + raise BlockingIOError() from err + except ssl.SSLWantWriteError as err: + self._call_socket_register_write() + raise BlockingIOError() from err + except AttributeError as err: + self._easy_log(MQTT_LOG_DEBUG, "socket was None: %s", err) + raise ConnectionError() from err def _sock_send(self, buf: bytes) -> int: if self._sock is None: raise ConnectionError("self._sock is None") - try: - return self._sock.send(buf) - except ssl.SSLWantReadError as err: - raise BlockingIOError() from err - except ssl.SSLWantWriteError as err: - self._call_socket_register_write() - raise BlockingIOError() from err - except BlockingIOError as err: - self._call_socket_register_write() - raise BlockingIOError() from err + if HAS_OPENSSL and isinstance(self._ssl_context, SSL.Context): + from OpenSSL import SSL + + try: + return self._sock.send(buf) + except SSL.WantReadError as err: + raise BlockingIOError() from err + except SSL.WantWriteError as err: + self._call_socket_register_write() + raise BlockingIOError() from err + except SSL.ZeroReturnError as err: + raise BlockingIOError() from err + except BlockingIOError as err: + self._call_socket_register_write() + raise BlockingIOError() from err + else: + try: + return self._sock.send(buf) + except ssl.SSLWantReadError as err: + raise BlockingIOError() from err + except ssl.SSLWantWriteError as err: + self._call_socket_register_write() + raise BlockingIOError() from err + except BlockingIOError as err: + self._call_socket_register_write() + raise BlockingIOError() from err def _sock_close(self) -> None: """Close the connection to the server.""" @@ -1207,7 +1323,7 @@ def ws_set_options( path: str = "/mqtt", headers: WebSocketHeaders | None = None, ) -> None: - """ Set the path and headers for a websocket connection + """Set the path and headers for a websocket connection :param str path: a string starting with / which should be the endpoint of the mqtt connection on the remote server @@ -1224,7 +1340,8 @@ def ws_set_options( self._websocket_extra_headers = headers else: raise ValueError( - "'headers' option to ws_set_options has to be either a dictionary or callable") + "'headers' option to ws_set_options has to be either a dictionary or callable" + ) def tls_set_context( self, @@ -1238,7 +1355,7 @@ def tls_set_context( Must be called before `connect()`, `connect_async()` or `connect_srv()`. """ if self._ssl_context is not None: - raise ValueError('SSL/TLS has already been configured.') + raise ValueError("SSL/TLS has already been configured.") if context is None: if HAS_OPENSSL: @@ -1250,7 +1367,7 @@ def tls_set_context( self._ssl_context = context # Ensure _tls_insecure is consistent with check_hostname attribute for ssl.SSLContext - if isinstance(context, ssl.SSLContext) and hasattr(context, 'check_hostname'): + if isinstance(context, ssl.SSLContext) and hasattr(context, "check_hostname"): self._tls_insecure = not context.check_hostname elif HAS_OPENSSL and isinstance(context, SSL.Context): # PyOpenSSL Context does not have check_hostname attribute @@ -1314,15 +1431,16 @@ def tls_set( Must be called before `connect()`, `connect_async()` or `connect_srv()`.""" if ssl is None: - raise ValueError('This platform has no SSL/TLS.') + raise ValueError("This platform has no SSL/TLS.") - if not hasattr(ssl, 'SSLContext'): + if not hasattr(ssl, "SSLContext"): # Require Python version that has SSL context support in standard library raise ValueError( - 'Python 2.7.9 and 3.2 are the minimum supported versions for TLS.') + "Python 2.7.9 and 3.2 are the minimum supported versions for TLS." + ) - if ca_certs is None and not hasattr(ssl.SSLContext, 'load_default_certs'): - raise ValueError('ca_certs must not be None.') + if ca_certs is None and not hasattr(ssl.SSLContext, "load_default_certs"): + raise ValueError("ca_certs must not be None.") # Create SSLContext object if tls_version is None: @@ -1342,7 +1460,7 @@ def tls_set( if certfile is not None: context.load_cert_chain(certfile, keyfile, keyfile_password) - if cert_reqs == ssl.CERT_NONE and hasattr(context, 'check_hostname'): + if cert_reqs == ssl.CERT_NONE and hasattr(context, "check_hostname"): context.check_hostname = False context.verify_mode = ssl.CERT_REQUIRED if cert_reqs is None else cert_reqs @@ -1384,12 +1502,13 @@ def tls_insecure_set(self, value: bool) -> None: if self._ssl_context is None: raise ValueError( - 'Must configure SSL context before using tls_insecure_set.') + "Must configure SSL context before using tls_insecure_set." + ) self._tls_insecure = value # Ensure check_hostname is consistent with _tls_insecure attribute - if hasattr(self._ssl_context, 'check_hostname'): + if hasattr(self._ssl_context, "check_hostname"): # Rely on SSLContext to check host name # If verify_mode is CERT_NONE then the host name will never be checked self._ssl_context.check_hostname = not value @@ -1488,8 +1607,9 @@ def connect( if properties: raise ValueError("Properties only apply to MQTT V5") - self.connect_async(host, port, keepalive, - bind_address, bind_port, clean_start, properties) + self.connect_async( + host, port, keepalive, bind_address, bind_port, clean_start, properties + ) return self.reconnect() def connect_srv( @@ -1510,23 +1630,27 @@ def connect_srv( if HAVE_DNS is False: raise ValueError( - 'No DNS resolver library found, try "pip install dnspython".') + 'No DNS resolver library found, try "pip install dnspython".' + ) if domain is None: domain = socket.getfqdn() - domain = domain[domain.find('.') + 1:] + domain = domain[domain.find(".") + 1 :] try: - rr = f'_mqtt._tcp.{domain}' + rr = f"_mqtt._tcp.{domain}" if self._ssl: # IANA specifies secure-mqtt (not mqtts) for port 8883 - rr = f'_secure-mqtt._tcp.{domain}' + rr = f"_secure-mqtt._tcp.{domain}" answers = [] for answer in dns.resolver.query(rr, dns.rdatatype.SRV): addr = answer.target.to_text()[:-1] - answers.append( - (addr, answer.port, answer.priority, answer.weight)) - except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer, dns.resolver.NoNameservers) as err: + answers.append((addr, answer.port, answer.priority, answer.weight)) + except ( + dns.resolver.NXDOMAIN, + dns.resolver.NoAnswer, + dns.resolver.NoNameservers, + ) as err: raise ValueError(f"No answer/NXDOMAIN for SRV in {domain}") from err # FIXME: doesn't account for weight @@ -1534,7 +1658,15 @@ def connect_srv( host, port, prio, weight = answer try: - return self.connect(host, port, keepalive, bind_address, bind_port, clean_start, properties) + return self.connect( + host, + port, + keepalive, + bind_address, + bind_port, + clean_start, + properties, + ) except Exception: # noqa: S110 pass @@ -1573,7 +1705,7 @@ def connect_async( MQTT connect packet. """ if bind_port < 0: - raise ValueError('Invalid bind port number.') + raise ValueError("Invalid bind port number.") # Switch to state NEW to allow update of host, port & co. self._sock_close() @@ -1589,12 +1721,12 @@ def connect_async( self._state = _ConnectionState.MQTT_CS_CONNECT_ASYNC def reconnect_delay_set(self, min_delay: int = 1, max_delay: int = 120) -> None: - """ Configure the exponential reconnect delay + """Configure the exponential reconnect delay - When connection is lost, wait initially min_delay seconds and - double this time every attempt. The wait is capped at max_delay. - Once the client is fully connected (e.g. not only TCP socket, but - received a success CONNACK), the wait timer is reset to min_delay. + When connection is lost, wait initially min_delay seconds and + double this time every attempt. The wait is capped at max_delay. + Once the client is fully connected (e.g. not only TCP socket, but + received a success CONNACK), the wait timer is reset to min_delay. """ with self._reconnect_delay_mutex: self._reconnect_min_delay = min_delay @@ -1605,9 +1737,9 @@ def reconnect(self) -> MQTTErrorCode: """Reconnect the client after a disconnect. Can only be called after connect()/connect_async().""" if len(self._host) == 0: - raise ValueError('Invalid host.') + raise ValueError("Invalid host.") if self._port <= 0: - raise ValueError('Invalid port number.') + raise ValueError("Invalid port number.") self._in_packet = { "command": 0, @@ -1628,7 +1760,11 @@ def reconnect(self) -> MQTTErrorCode: # Mark all currently outgoing QoS = 0 packets as lost, # or `wait_for_publish()` could hang forever for pkt in self._out_packet: - if pkt["command"] & 0xF0 == PUBLISH and pkt["qos"] == 0 and pkt["info"] is not None: + if ( + pkt["command"] & 0xF0 == PUBLISH + and pkt["qos"] == 0 + and pkt["info"] is not None + ): pkt["info"].rc = MQTT_ERR_CONN_LOST pkt["info"]._set_as_published() @@ -1649,7 +1785,8 @@ def reconnect(self) -> MQTTErrorCode: on_pre_connect(self, self._userdata) except Exception as err: self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_pre_connect: %s', err) + MQTT_LOG_ERR, "Caught exception in on_pre_connect: %s", err + ) if not self.suppress_exceptions: raise @@ -1694,7 +1831,7 @@ def loop(self, timeout: float = 1.0) -> MQTTErrorCode: def _loop(self, timeout: float = 1.0) -> MQTTErrorCode: if timeout < 0.0: - raise ValueError('Invalid timeout.') + raise ValueError("Invalid timeout.") if self.want_write(): wlist = [self._sock] @@ -1703,7 +1840,7 @@ def _loop(self, timeout: float = 1.0) -> MQTTErrorCode: # used to check if there are any bytes left in the (SSL) socket pending_bytes = 0 - if hasattr(self._sock, 'pending'): + if hasattr(self._sock, "pending"): pending_bytes = self._sock.pending() # type: ignore[union-attr] # if bytes are pending do not wait in select @@ -1726,13 +1863,19 @@ def _loop(self, timeout: float = 1.0) -> MQTTErrorCode: # call _loop(). We still want to break that loop by returning an # rc != MQTT_ERR_SUCCESS and we don't want state to change from # mqtt_cs_disconnecting. - if self._state not in (_ConnectionState.MQTT_CS_DISCONNECTING, _ConnectionState.MQTT_CS_DISCONNECTED): + if self._state not in ( + _ConnectionState.MQTT_CS_DISCONNECTING, + _ConnectionState.MQTT_CS_DISCONNECTED, + ): self._state = _ConnectionState.MQTT_CS_CONNECTION_LOST return MQTTErrorCode.MQTT_ERR_CONN_LOST except ValueError: # Can occur if we just reconnected but rlist/wlist contain a -1 for # some reason. - if self._state not in (_ConnectionState.MQTT_CS_DISCONNECTING, _ConnectionState.MQTT_CS_DISCONNECTED): + if self._state not in ( + _ConnectionState.MQTT_CS_DISCONNECTING, + _ConnectionState.MQTT_CS_DISCONNECTED, + ): self._state = _ConnectionState.MQTT_CS_CONNECTION_LOST return MQTTErrorCode.MQTT_ERR_CONN_LOST except Exception: @@ -1811,26 +1954,34 @@ def publish( """ if self._protocol != MQTTv5: if topic is None or len(topic) == 0: - raise ValueError('Invalid topic.') + raise ValueError("Invalid topic.") - topic_bytes = topic.encode('utf-8') + topic_bytes = topic.encode("utf-8") self._raise_for_invalid_topic(topic_bytes) if qos < 0 or qos > 2: - raise ValueError('Invalid QoS level.') + raise ValueError("Invalid QoS level.") local_payload = _encode_payload(payload) if len(local_payload) > 268435455: - raise ValueError('Payload too large.') + raise ValueError("Payload too large.") local_mid = self._mid_generate() if qos == 0: info = MQTTMessageInfo(local_mid) rc = self._send_publish( - local_mid, topic_bytes, local_payload, qos, retain, False, info, properties) + local_mid, + topic_bytes, + local_payload, + qos, + retain, + False, + info, + properties, + ) info.rc = rc return info else: @@ -1843,7 +1994,10 @@ def publish( message.properties = properties with self._out_message_mutex: - if self._max_queued_messages > 0 and len(self._out_messages) >= self._max_queued_messages: + if ( + self._max_queued_messages > 0 + and len(self._out_messages) >= self._max_queued_messages + ): message.info.rc = MQTTErrorCode.MQTT_ERR_QUEUE_SIZE return message.info @@ -1852,15 +2006,26 @@ def publish( return message.info self._out_messages[message.mid] = message - if self._max_inflight_messages == 0 or self._inflight_messages < self._max_inflight_messages: + if ( + self._max_inflight_messages == 0 + or self._inflight_messages < self._max_inflight_messages + ): self._inflight_messages += 1 if qos == 1: message.state = mqtt_ms_wait_for_puback elif qos == 2: message.state = mqtt_ms_wait_for_pubrec - rc = self._send_publish(message.mid, topic_bytes, message.payload, message.qos, message.retain, - message.dup, message.info, message.properties) + rc = self._send_publish( + message.mid, + topic_bytes, + message.payload, + message.qos, + message.retain, + message.dup, + message.info, + message.properties, + ) # remove from inflight messages so it will be send after a connection is made if rc == MQTTErrorCode.MQTT_ERR_NO_CONN: @@ -1890,9 +2055,9 @@ def username_pw_set( """ # [MQTT-3.1.3-11] User name must be UTF-8 encoded string - self._username = None if username is None else username.encode('utf-8') + self._username = None if username is None else username.encode("utf-8") if isinstance(password, str): - self._password = password.encode('utf-8') + self._password = password.encode("utf-8") else: self._password = password @@ -1916,9 +2081,14 @@ def _connection_closed(self) -> bool: """ Return true if the connection is closed (and not trying to be opened). """ - return ( - self._state == _ConnectionState.MQTT_CS_NEW - or (self._state in (_ConnectionState.MQTT_CS_DISCONNECTING, _ConnectionState.MQTT_CS_DISCONNECTED) and self._sock is None)) + return self._state == _ConnectionState.MQTT_CS_NEW or ( + self._state + in ( + _ConnectionState.MQTT_CS_DISCONNECTING, + _ConnectionState.MQTT_CS_DISCONNECTED, + ) + and self._sock is None + ) def is_connected(self) -> bool: """Returns the current status of the connection @@ -1951,7 +2121,13 @@ def disconnect( def subscribe( self, - topic: str | tuple[str, int] | tuple[str, SubscribeOptions] | list[tuple[str, int]] | list[tuple[str, SubscribeOptions]], + topic: ( + str + | tuple[str, int] + | tuple[str, SubscribeOptions] + | list[tuple[str, int]] + | list[tuple[str, SubscribeOptions]] + ), qos: int = 0, options: SubscribeOptions | None = None, properties: Properties | None = None, @@ -2042,53 +2218,59 @@ def subscribe( topic, options = topic # type: ignore if not isinstance(options, SubscribeOptions): raise ValueError( - 'Subscribe options must be instance of SubscribeOptions class.') + "Subscribe options must be instance of SubscribeOptions class." + ) else: topic, qos = topic # type: ignore if isinstance(topic, (bytes, str)): if qos < 0 or qos > 2: - raise ValueError('Invalid QoS level.') + raise ValueError("Invalid QoS level.") if self._protocol == MQTTv5: if options is None: # if no options are provided, use the QoS passed instead options = SubscribeOptions(qos=qos) elif qos != 0: raise ValueError( - 'Subscribe options and qos parameters cannot be combined.') + "Subscribe options and qos parameters cannot be combined." + ) if not isinstance(options, SubscribeOptions): raise ValueError( - 'Subscribe options must be instance of SubscribeOptions class.') - topic_qos_list = [(topic.encode('utf-8'), options)] + "Subscribe options must be instance of SubscribeOptions class." + ) + topic_qos_list = [(topic.encode("utf-8"), options)] else: if topic is None or len(topic) == 0: - raise ValueError('Invalid topic.') - topic_qos_list = [(topic.encode('utf-8'), qos)] # type: ignore + raise ValueError("Invalid topic.") + topic_qos_list = [(topic.encode("utf-8"), qos)] # type: ignore elif isinstance(topic, list): if len(topic) == 0: - raise ValueError('Empty topic list') + raise ValueError("Empty topic list") topic_qos_list = [] if self._protocol == MQTTv5: for t, o in topic: if not isinstance(o, SubscribeOptions): # then the second value should be QoS if o < 0 or o > 2: - raise ValueError('Invalid QoS level.') + raise ValueError("Invalid QoS level.") o = SubscribeOptions(qos=o) - topic_qos_list.append((t.encode('utf-8'), o)) + topic_qos_list.append((t.encode("utf-8"), o)) else: for t, q in topic: if isinstance(q, SubscribeOptions) or q < 0 or q > 2: - raise ValueError('Invalid QoS level.') + raise ValueError("Invalid QoS level.") if t is None or len(t) == 0 or not isinstance(t, (bytes, str)): - raise ValueError('Invalid topic.') - topic_qos_list.append((t.encode('utf-8'), q)) # type: ignore + raise ValueError("Invalid topic.") + topic_qos_list.append((t.encode("utf-8"), q)) # type: ignore if topic_qos_list is None: raise ValueError("No topic specified, or incorrect topic type.") - if any(self._filter_wildcard_len_check(topic) != MQTT_ERR_SUCCESS for topic, _ in topic_qos_list): - raise ValueError('Invalid subscription filter.') + if any( + self._filter_wildcard_len_check(topic) != MQTT_ERR_SUCCESS + for topic, _ in topic_qos_list + ): + raise ValueError("Invalid subscription filter.") if self._sock is None: return (MQTT_ERR_NO_CONN, None) @@ -2117,17 +2299,17 @@ def unsubscribe( """ topic_list = None if topic is None: - raise ValueError('Invalid topic.') + raise ValueError("Invalid topic.") if isinstance(topic, (bytes, str)): if len(topic) == 0: - raise ValueError('Invalid topic.') - topic_list = [topic.encode('utf-8')] + raise ValueError("Invalid topic.") + topic_list = [topic.encode("utf-8")] elif isinstance(topic, list): topic_list = [] for t in topic: if len(t) == 0 or not isinstance(t, (bytes, str)): - raise ValueError('Invalid topic.') - topic_list.append(t.encode('utf-8')) + raise ValueError("Invalid topic.") + topic_list.append(t.encode("utf-8")) if topic_list is None: raise ValueError("No topic specified, or incorrect topic type.") @@ -2211,7 +2393,10 @@ def loop_misc(self) -> MQTTErrorCode: # This hasn't happened in the keepalive time so we should disconnect. self._sock_close() - if self._state in (_ConnectionState.MQTT_CS_DISCONNECTING, _ConnectionState.MQTT_CS_DISCONNECTED): + if self._state in ( + _ConnectionState.MQTT_CS_DISCONNECTING, + _ConnectionState.MQTT_CS_DISCONNECTED, + ): self._state = _ConnectionState.MQTT_CS_DISCONNECTED rc = MQTTErrorCode.MQTT_ERR_SUCCESS else: @@ -2236,7 +2421,7 @@ def max_queued_messages_set(self, queue_size: int) -> Client: """Set the maximum number of messages in the outgoing message queue. 0 means unlimited.""" if not isinstance(queue_size, int): - raise ValueError('Invalid type of queue size.') + raise ValueError("Invalid type of queue size.") self.max_queued_messages = queue_size return self @@ -2279,24 +2464,25 @@ def will_set( for example by calling `disconnect()`. """ if topic is None or len(topic) == 0: - raise ValueError('Invalid topic.') + raise ValueError("Invalid topic.") if qos < 0 or qos > 2: - raise ValueError('Invalid QoS level.') + raise ValueError("Invalid QoS level.") if properties and not isinstance(properties, Properties): raise ValueError( - "The properties argument must be an instance of the Properties class.") + "The properties argument must be an instance of the Properties class." + ) self._will_payload = _encode_payload(payload) self._will = True - self._will_topic = topic.encode('utf-8') + self._will_topic = topic.encode("utf-8") self._will_qos = qos self._will_retain = retain self._will_properties = properties def will_clear(self) -> None: - """ Removes a will that was previously configured with `will_set()`. + """Removes a will that was previously configured with `will_set()`. Must be called before connect() to have any effect.""" self._will = False @@ -2343,8 +2529,7 @@ def loop_forever( self._handle_on_connect_fail() if not retry_first_connection: raise - self._easy_log( - MQTT_LOG_DEBUG, "Connection failed, retrying") + self._easy_log(MQTT_LOG_DEBUG, "Connection failed, retrying") self._reconnect_wait() else: break @@ -2357,17 +2542,24 @@ def loop_forever( # either called loop_forever() when in single threaded mode, or # in multi threaded mode when loop_stop() has been called and # so no other threads can access _out_packet or _messages. - if (self._thread_terminate is True + if ( + self._thread_terminate is True and len(self._out_packet) == 0 - and len(self._out_messages) == 0): + and len(self._out_messages) == 0 + ): rc = MQTTErrorCode.MQTT_ERR_NOMEM run = False def should_exit() -> bool: return ( - self._state in (_ConnectionState.MQTT_CS_DISCONNECTING, _ConnectionState.MQTT_CS_DISCONNECTED) or - run is False or # noqa: B023 (uses the run variable from the outer scope on purpose) - self._thread_terminate is True + self._state + in ( + _ConnectionState.MQTT_CS_DISCONNECTING, + _ConnectionState.MQTT_CS_DISCONNECTED, + ) + or run + is False # noqa: B023 (uses the run variable from the outer scope on purpose) + or self._thread_terminate is True ) if should_exit() or not self._reconnect_on_failure: @@ -2382,8 +2574,7 @@ def should_exit() -> bool: self.reconnect() except OSError: self._handle_on_connect_fail() - self._easy_log( - MQTT_LOG_DEBUG, "Connection failed, retrying") + self._easy_log(MQTT_LOG_DEBUG, "Connection failed, retrying") return rc @@ -2400,7 +2591,10 @@ def loop_start(self) -> MQTTErrorCode: self._sockpairR, self._sockpairW = _socketpair_compat() self._thread_terminate = False - self._thread = threading.Thread(target=self._thread_main, name=f"paho-mqtt-client-{self._client_id.decode()}") + self._thread = threading.Thread( + target=self._thread_main, + name=f"paho-mqtt-client-{self._client_id.decode()}", + ) self._thread.daemon = True self._thread.start() @@ -2462,6 +2656,7 @@ def log_callback(self) -> Callable[[CallbackOnLog], CallbackOnLog]: def decorator(func: CallbackOnLog) -> CallbackOnLog: self.on_log = func return func + return decorator @property @@ -2493,6 +2688,7 @@ def pre_connect_callback( def decorator(func: CallbackOnPreConnect) -> CallbackOnPreConnect: self.on_pre_connect = func return func + return decorator @property @@ -2559,6 +2755,7 @@ def connect_callback( def decorator(func: CallbackOnConnect) -> CallbackOnConnect: self.on_connect = func return func + return decorator @property @@ -2589,6 +2786,7 @@ def connect_fail_callback( def decorator(func: CallbackOnConnectFail) -> CallbackOnConnectFail: self.on_connect_fail = func return func + return decorator @property @@ -2639,6 +2837,7 @@ def subscribe_callback( def decorator(func: CallbackOnSubscribe) -> CallbackOnSubscribe: self.on_subscribe = func return func + return decorator @property @@ -2673,6 +2872,7 @@ def message_callback( def decorator(func: CallbackOnMessage) -> CallbackOnMessage: self.on_message = func return func + return decorator @property @@ -2729,6 +2929,7 @@ def publish_callback( def decorator(func: CallbackOnPublish) -> CallbackOnPublish: self.on_publish = func return func + return decorator @property @@ -2780,6 +2981,7 @@ def unsubscribe_callback( def decorator(func: CallbackOnUnsubscribe) -> CallbackOnUnsubscribe: self.on_unsubscribe = func return func + return decorator @property @@ -2833,6 +3035,7 @@ def disconnect_callback( def decorator(func: CallbackOnDisconnect) -> CallbackOnDisconnect: self.on_disconnect = func return func + return decorator @property @@ -2865,6 +3068,7 @@ def socket_open_callback( def decorator(func: CallbackOnSocket) -> CallbackOnSocket: self.on_socket_open = func return func + return decorator def _call_socket_open(self, sock: SocketLike) -> None: @@ -2878,7 +3082,8 @@ def _call_socket_open(self, sock: SocketLike) -> None: on_socket_open(self, self._userdata, sock) except Exception as err: self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_socket_open: %s', err) + MQTT_LOG_ERR, "Caught exception in on_socket_open: %s", err + ) if not self.suppress_exceptions: raise @@ -2912,6 +3117,7 @@ def socket_close_callback( def decorator(func: CallbackOnSocket) -> CallbackOnSocket: self.on_socket_close = func return func + return decorator def _call_socket_close(self, sock: SocketLike) -> None: @@ -2925,7 +3131,8 @@ def _call_socket_close(self, sock: SocketLike) -> None: on_socket_close(self, self._userdata, sock) except Exception as err: self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_socket_close: %s', err) + MQTT_LOG_ERR, "Caught exception in on_socket_close: %s", err + ) if not self.suppress_exceptions: raise @@ -2959,6 +3166,7 @@ def socket_register_write_callback( def decorator(func: CallbackOnSocket) -> CallbackOnSocket: self._on_socket_register_write = func return func + return decorator def _call_socket_register_write(self) -> None: @@ -2971,11 +3179,13 @@ def _call_socket_register_write(self) -> None: if on_socket_register_write: try: - on_socket_register_write( - self, self._userdata, self._sock) + on_socket_register_write(self, self._userdata, self._sock) except Exception as err: self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_socket_register_write: %s', err) + MQTT_LOG_ERR, + "Caught exception in on_socket_register_write: %s", + err, + ) if not self.suppress_exceptions: raise @@ -3001,9 +3211,7 @@ def on_socket_unregister_write( return self._on_socket_unregister_write @on_socket_unregister_write.setter - def on_socket_unregister_write( - self, func: CallbackOnSocket | None - ) -> None: + def on_socket_unregister_write(self, func: CallbackOnSocket | None) -> None: with self._callback_mutex: self._on_socket_unregister_write = func @@ -3015,11 +3223,10 @@ def decorator( ) -> CallbackOnSocket: self._on_socket_unregister_write = func return func + return decorator - def _call_socket_unregister_write( - self, sock: SocketLike | None = None - ) -> None: + def _call_socket_unregister_write(self, sock: SocketLike | None = None) -> None: """Call the socket_unregister_write callback with the writable socket""" sock = sock or self._sock if not sock or not self._registered_write: @@ -3034,7 +3241,10 @@ def _call_socket_unregister_write( on_socket_unregister_write(self, self._userdata, sock) except Exception as err: self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_socket_unregister_write: %s', err) + MQTT_LOG_ERR, + "Caught exception in on_socket_unregister_write: %s", + err, + ) if not self.suppress_exceptions: raise @@ -3073,6 +3283,7 @@ def topic_callback( def decorator(func: CallbackOnMessage) -> CallbackOnMessage: self.message_callback_add(sub, func) return func + return decorator def message_callback_remove(self, sub: str) -> None: @@ -3098,7 +3309,10 @@ def _loop_rc_handle( if rc: self._sock_close() - if self._state in (_ConnectionState.MQTT_CS_DISCONNECTING, _ConnectionState.MQTT_CS_DISCONNECTED): + if self._state in ( + _ConnectionState.MQTT_CS_DISCONNECTING, + _ConnectionState.MQTT_CS_DISCONNECTED, + ): self._state = _ConnectionState.MQTT_CS_DISCONNECTED rc = MQTTErrorCode.MQTT_ERR_SUCCESS @@ -3123,25 +3337,23 @@ def _packet_read(self) -> MQTTErrorCode: # fail due to longer length, so save current data and current position. # After all data is read, send to _mqtt_handle_packet() to deal with. # Finally, free the memory and reset everything to starting conditions. - if self._in_packet['command'] == 0: + if self._in_packet["command"] == 0: try: command = self._sock_recv(1) except BlockingIOError: return MQTTErrorCode.MQTT_ERR_AGAIN except TimeoutError as err: - self._easy_log( - MQTT_LOG_ERR, 'timeout on socket: %s', err) + self._easy_log(MQTT_LOG_ERR, "timeout on socket: %s", err) return MQTTErrorCode.MQTT_ERR_CONN_LOST except OSError as err: - self._easy_log( - MQTT_LOG_ERR, 'failed to receive on socket: %s', err) + self._easy_log(MQTT_LOG_ERR, "failed to receive on socket: %s", err) return MQTTErrorCode.MQTT_ERR_CONN_LOST else: if len(command) == 0: return MQTTErrorCode.MQTT_ERR_CONN_LOST - self._in_packet['command'] = command[0] + self._in_packet["command"] = command[0] - if self._in_packet['have_remaining'] == 0: + if self._in_packet["have_remaining"] == 0: # Read remaining # Algorithm for decoding taken from pseudo code at # http://publib.boulder.ibm.com/infocenter/wmbhelp/v6r0m0/topic/com.ibm.etools.mft.doc/ac10870_.htm @@ -3151,44 +3363,45 @@ def _packet_read(self) -> MQTTErrorCode: except BlockingIOError: return MQTTErrorCode.MQTT_ERR_AGAIN except OSError as err: - self._easy_log( - MQTT_LOG_ERR, 'failed to receive on socket: %s', err) + self._easy_log(MQTT_LOG_ERR, "failed to receive on socket: %s", err) return MQTTErrorCode.MQTT_ERR_CONN_LOST else: if len(byte) == 0: return MQTTErrorCode.MQTT_ERR_CONN_LOST byte_value = byte[0] - self._in_packet['remaining_count'].append(byte_value) + self._in_packet["remaining_count"].append(byte_value) # Max 4 bytes length for remaining length as defined by protocol. # Anything more likely means a broken/malicious client. - if len(self._in_packet['remaining_count']) > 4: + if len(self._in_packet["remaining_count"]) > 4: return MQTTErrorCode.MQTT_ERR_PROTOCOL - self._in_packet['remaining_length'] += ( - byte_value & 127) * self._in_packet['remaining_mult'] - self._in_packet['remaining_mult'] = self._in_packet['remaining_mult'] * 128 + self._in_packet["remaining_length"] += ( + byte_value & 127 + ) * self._in_packet["remaining_mult"] + self._in_packet["remaining_mult"] = ( + self._in_packet["remaining_mult"] * 128 + ) if (byte_value & 128) == 0: break - self._in_packet['have_remaining'] = 1 - self._in_packet['to_process'] = self._in_packet['remaining_length'] + self._in_packet["have_remaining"] = 1 + self._in_packet["to_process"] = self._in_packet["remaining_length"] - count = 100 # Don't get stuck in this loop if we have a huge message. - while self._in_packet['to_process'] > 0: + count = 100 # Don't get stuck in this loop if we have a huge message. + while self._in_packet["to_process"] > 0: try: - data = self._sock_recv(self._in_packet['to_process']) + data = self._sock_recv(self._in_packet["to_process"]) except BlockingIOError: return MQTTErrorCode.MQTT_ERR_AGAIN except OSError as err: - self._easy_log( - MQTT_LOG_ERR, 'failed to receive on socket: %s', err) + self._easy_log(MQTT_LOG_ERR, "failed to receive on socket: %s", err) return MQTTErrorCode.MQTT_ERR_CONN_LOST else: if len(data) == 0: return MQTTErrorCode.MQTT_ERR_CONN_LOST - self._in_packet['to_process'] -= len(data) - self._in_packet['packet'] += data + self._in_packet["to_process"] -= len(data) + self._in_packet["packet"] += data count -= 1 if count == 0: with self._msgtime_mutex: @@ -3196,7 +3409,7 @@ def _packet_read(self) -> MQTTErrorCode: return MQTTErrorCode.MQTT_ERR_AGAIN # All data for this packet is read. - self._in_packet['pos'] = 0 + self._in_packet["pos"] = 0 rc = self._packet_handle() # Free data and reset values @@ -3223,8 +3436,7 @@ def _packet_write(self) -> MQTTErrorCode: return MQTTErrorCode.MQTT_ERR_SUCCESS try: - write_length = self._sock_send( - packet['packet'][packet['pos']:]) + write_length = self._sock_send(packet["packet"][packet["pos"] :]) except (AttributeError, ValueError): self._out_packet.appendleft(packet) return MQTTErrorCode.MQTT_ERR_SUCCESS @@ -3233,28 +3445,37 @@ def _packet_write(self) -> MQTTErrorCode: return MQTTErrorCode.MQTT_ERR_AGAIN except OSError as err: self._out_packet.appendleft(packet) - self._easy_log( - MQTT_LOG_ERR, 'failed to receive on socket: %s', err) + self._easy_log(MQTT_LOG_ERR, "failed to receive on socket: %s", err) return MQTTErrorCode.MQTT_ERR_CONN_LOST if write_length > 0: - packet['to_process'] -= write_length - packet['pos'] += write_length + packet["to_process"] -= write_length + packet["pos"] += write_length - if packet['to_process'] == 0: - if (packet['command'] & 0xF0) == PUBLISH and packet['qos'] == 0: + if packet["to_process"] == 0: + if (packet["command"] & 0xF0) == PUBLISH and packet["qos"] == 0: with self._callback_mutex: on_publish = self.on_publish if on_publish: with self._in_callback_mutex: try: - if self._callback_api_version == CallbackAPIVersion.VERSION1: - on_publish = cast(CallbackOnPublish_v1, on_publish) + if ( + self._callback_api_version + == CallbackAPIVersion.VERSION1 + ): + on_publish = cast( + CallbackOnPublish_v1, on_publish + ) on_publish(self, self._userdata, packet["mid"]) - elif self._callback_api_version == CallbackAPIVersion.VERSION2: - on_publish = cast(CallbackOnPublish_v2, on_publish) + elif ( + self._callback_api_version + == CallbackAPIVersion.VERSION2 + ): + on_publish = cast( + CallbackOnPublish_v2, on_publish + ) on_publish( self, @@ -3264,10 +3485,15 @@ def _packet_write(self) -> MQTTErrorCode: Properties(PacketTypes.PUBACK), ) else: - raise RuntimeError("Unsupported callback API version") + raise RuntimeError( + "Unsupported callback API version" + ) except Exception as err: self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_publish: %s', err) + MQTT_LOG_ERR, + "Caught exception in on_publish: %s", + err, + ) if not self.suppress_exceptions: raise @@ -3275,9 +3501,9 @@ def _packet_write(self) -> MQTTErrorCode: # A packet could be produced by _handle_connack with qos=0 and no info # (around line 3645). Ignore the mypy check for now but I feel there is a bug # somewhere. - packet['info']._set_as_published() # type: ignore + packet["info"]._set_as_published() # type: ignore - if (packet['command'] & 0xF0) == DISCONNECT: + if (packet["command"] & 0xF0) == DISCONNECT: with self._msgtime_mutex: self._last_msg_out = time_func() @@ -3326,7 +3552,10 @@ def _check_keepalive(self) -> None: last_msg_out = self._last_msg_out last_msg_in = self._last_msg_in - if self._sock is not None and (now - last_msg_out >= self._keepalive or now - last_msg_in >= self._keepalive): + if self._sock is not None and ( + now - last_msg_out >= self._keepalive + or now - last_msg_in >= self._keepalive + ): if self._state == _ConnectionState.MQTT_CS_CONNECTED and self._ping_t == 0: try: self._send_pingreq() @@ -3343,7 +3572,10 @@ def _check_keepalive(self) -> None: else: self._sock_close() - if self._state in (_ConnectionState.MQTT_CS_DISCONNECTING, _ConnectionState.MQTT_CS_DISCONNECTED): + if self._state in ( + _ConnectionState.MQTT_CS_DISCONNECTING, + _ConnectionState.MQTT_CS_DISCONNECTED, + ): self._state = _ConnectionState.MQTT_CS_DISCONNECTED rc = MQTTErrorCode.MQTT_ERR_SUCCESS else: @@ -3363,20 +3595,23 @@ def _mid_generate(self) -> int: @staticmethod def _raise_for_invalid_topic(topic: bytes) -> None: - """ Check if the topic is a topic without wildcard and valid length. + """Check if the topic is a topic without wildcard and valid length. - Raise ValueError if the topic isn't valid. + Raise ValueError if the topic isn't valid. """ - if b'+' in topic or b'#' in topic: - raise ValueError('Publish topic cannot contain wildcards.') + if b"+" in topic or b"#" in topic: + raise ValueError("Publish topic cannot contain wildcards.") if len(topic) > 65535: - raise ValueError('Publish topic is too long.') + raise ValueError("Publish topic is too long.") @staticmethod def _filter_wildcard_len_check(sub: bytes) -> MQTTErrorCode: - if (len(sub) == 0 or len(sub) > 65535 - or any(b'+' in p or b'#' in p for p in sub.split(b'/') if len(p) > 1) - or b'#/' in sub): + if ( + len(sub) == 0 + or len(sub) > 65535 + or any(b"+" in p or b"#" in p for p in sub.split(b"/") if len(p) > 1) + or b"#/" in sub + ): return MQTTErrorCode.MQTT_ERR_INVAL else: return MQTTErrorCode.MQTT_ERR_SUCCESS @@ -3426,7 +3661,7 @@ def _send_publish( self, mid: int, topic: bytes, - payload: bytes|bytearray = b"", + payload: bytes | bytearray = b"", qos: int = 0, retain: bool = False, dup: bool = False, @@ -3435,9 +3670,9 @@ def _send_publish( ) -> MQTTErrorCode: # we assume that topic and payload are already properly encoded if not isinstance(topic, bytes): - raise TypeError('topic must be bytes, not str') + raise TypeError("topic must be bytes, not str") if payload and not isinstance(payload, (bytes, bytearray)): - raise TypeError('payload must be bytes if set') + raise TypeError("payload must be bytes if set") if self._sock is None: return MQTTErrorCode.MQTT_ERR_NO_CONN @@ -3454,26 +3689,46 @@ def _send_publish( self._easy_log( MQTT_LOG_DEBUG, "Sending PUBLISH (d%d, q%d, r%d, m%d), '%s', properties=%s (NULL payload)", - dup, qos, retain, mid, topic, properties + dup, + qos, + retain, + mid, + topic, + properties, ) else: self._easy_log( MQTT_LOG_DEBUG, "Sending PUBLISH (d%d, q%d, r%d, m%d), '%s' (NULL payload)", - dup, qos, retain, mid, topic + dup, + qos, + retain, + mid, + topic, ) else: if self._protocol == MQTTv5: self._easy_log( MQTT_LOG_DEBUG, "Sending PUBLISH (d%d, q%d, r%d, m%d), '%s', properties=%s, ... (%d bytes)", - dup, qos, retain, mid, topic, properties, payloadlen + dup, + qos, + retain, + mid, + topic, + properties, + payloadlen, ) else: self._easy_log( MQTT_LOG_DEBUG, "Sending PUBLISH (d%d, q%d, r%d, m%d), '%s', ... (%d bytes)", - dup, qos, retain, mid, topic, payloadlen + dup, + qos, + retain, + mid, + topic, + payloadlen, ) if qos > 0: @@ -3482,7 +3737,7 @@ def _send_publish( if self._protocol == MQTTv5: if properties is None: - packed_properties = b'\x00' + packed_properties = b"\x00" else: packed_properties = properties.pack() remaining_length += len(packed_properties) @@ -3515,13 +3770,13 @@ def _send_command_with_mid(self, command: int, mid: int, dup: int) -> MQTTErrorC command |= 0x8 remaining_length = 2 - packet = struct.pack('!BBH', command, remaining_length, mid) + packet = struct.pack("!BBH", command, remaining_length, mid) return self._packet_queue(command, packet, mid, 1) def _send_simple_command(self, command: int) -> MQTTErrorCode: # For DISCONNECT, PINGREQ and PINGRESP remaining_length = 0 - packet = struct.pack('!BB', command, remaining_length) + packet = struct.pack("!BB", command, remaining_length) return self._packet_queue(command, packet, 0, 0) def _send_connect(self, keepalive: int) -> MQTTErrorCode: @@ -3529,23 +3784,27 @@ def _send_connect(self, keepalive: int) -> MQTTErrorCode: # hard-coded UTF-8 encoded string protocol = b"MQTT" if proto_ver >= MQTTv311 else b"MQIsdp" - remaining_length = 2 + len(protocol) + 1 + \ - 1 + 2 + 2 + len(self._client_id) + remaining_length = 2 + len(protocol) + 1 + 1 + 2 + 2 + len(self._client_id) connect_flags = 0 if self._protocol == MQTTv5: if self._clean_start is True: connect_flags |= 0x02 - elif self._clean_start == MQTT_CLEAN_START_FIRST_ONLY and self._mqttv5_first_connect: + elif ( + self._clean_start == MQTT_CLEAN_START_FIRST_ONLY + and self._mqttv5_first_connect + ): connect_flags |= 0x02 elif self._clean_session: connect_flags |= 0x02 if self._will: - remaining_length += 2 + \ - len(self._will_topic) + 2 + len(self._will_payload) - connect_flags |= 0x04 | ((self._will_qos & 0x03) << 3) | ( - (self._will_retain & 0x01) << 5) + remaining_length += 2 + len(self._will_topic) + 2 + len(self._will_payload) + connect_flags |= ( + 0x04 + | ((self._will_qos & 0x03) << 3) + | ((self._will_retain & 0x01) << 5) + ) if self._username is not None: remaining_length += 2 + len(self._username) @@ -3556,13 +3815,13 @@ def _send_connect(self, keepalive: int) -> MQTTErrorCode: if self._protocol == MQTTv5: if self._connect_properties is None: - packed_connect_properties = b'\x00' + packed_connect_properties = b"\x00" else: packed_connect_properties = self._connect_properties.pack() remaining_length += len(packed_connect_properties) if self._will: if self._will_properties is None: - packed_will_properties = b'\x00' + packed_will_properties = b"\x00" else: packed_will_properties = self._will_properties.pack() remaining_length += len(packed_will_properties) @@ -3577,10 +3836,16 @@ def _send_connect(self, keepalive: int) -> MQTTErrorCode: proto_ver |= 0x80 self._pack_remaining_length(packet, remaining_length) - packet.extend(struct.pack( - f"!H{len(protocol)}sBBH", - len(protocol), protocol, proto_ver, connect_flags, keepalive, - )) + packet.extend( + struct.pack( + f"!H{len(protocol)}sBBH", + len(protocol), + protocol, + proto_ver, + connect_flags, + keepalive, + ) + ) if self._protocol == MQTTv5: packet += packed_connect_properties @@ -3612,7 +3877,7 @@ def _send_connect(self, keepalive: int) -> MQTTErrorCode: (connect_flags & 0x2) >> 1, keepalive, self._client_id, - self._connect_properties + self._connect_properties, ) else: self._easy_log( @@ -3625,7 +3890,7 @@ def _send_connect(self, keepalive: int) -> MQTTErrorCode: (connect_flags & 0x4) >> 2, (connect_flags & 0x2) >> 1, keepalive, - self._client_id + self._client_id, ) return self._packet_queue(command, packet, 0, 0) @@ -3635,10 +3900,12 @@ def _send_disconnect( properties: Properties | None = None, ) -> MQTTErrorCode: if self._protocol == MQTTv5: - self._easy_log(MQTT_LOG_DEBUG, "Sending DISCONNECT reasonCode=%s properties=%s", - reasoncode, - properties - ) + self._easy_log( + MQTT_LOG_DEBUG, + "Sending DISCONNECT reasonCode=%s properties=%s", + reasoncode, + properties, + ) else: self._easy_log(MQTT_LOG_DEBUG, "Sending DISCONNECT") @@ -3676,7 +3943,7 @@ def _send_subscribe( remaining_length = 2 if self._protocol == MQTTv5: if properties is None: - packed_subscribe_properties = b'\x00' + packed_subscribe_properties = b"\x00" else: packed_subscribe_properties = properties.pack() remaining_length += len(packed_subscribe_properties) @@ -3718,7 +3985,7 @@ def _send_unsubscribe( remaining_length = 2 if self._protocol == MQTTv5: if properties is None: - packed_unsubscribe_properties = b'\x00' + packed_unsubscribe_properties = b"\x00" else: packed_unsubscribe_properties = properties.pack() remaining_length += len(packed_unsubscribe_properties) @@ -3772,7 +4039,10 @@ def _messages_reconnect_reset_out(self) -> None: self._inflight_messages = 0 for m in self._out_messages.values(): m.timestamp = 0 - if self._max_inflight_messages == 0 or self._inflight_messages < self._max_inflight_messages: + if ( + self._max_inflight_messages == 0 + or self._inflight_messages < self._max_inflight_messages + ): if m.qos == 0: m.state = mqtt_ms_publish elif m.qos == 1: @@ -3853,7 +4123,7 @@ def _packet_queue( return MQTTErrorCode.MQTT_ERR_SUCCESS def _packet_handle(self) -> MQTTErrorCode: - cmd = self._in_packet['command'] & 0xF0 + cmd = self._in_packet["command"] & 0xF0 if cmd == PINGREQ: return self._handle_pingreq() elif cmd == PINGRESP: @@ -3884,14 +4154,14 @@ def _packet_handle(self) -> MQTTErrorCode: return MQTTErrorCode.MQTT_ERR_PROTOCOL def _handle_pingreq(self) -> MQTTErrorCode: - if self._in_packet['remaining_length'] != 0: + if self._in_packet["remaining_length"] != 0: return MQTTErrorCode.MQTT_ERR_PROTOCOL self._easy_log(MQTT_LOG_DEBUG, "Received PINGREQ") return self._send_pingresp() def _handle_pingresp(self) -> MQTTErrorCode: - if self._in_packet['remaining_length'] != 0: + if self._in_packet["remaining_length"] != 0: return MQTTErrorCode.MQTT_ERR_PROTOCOL # No longer waiting for a PINGRESP. @@ -3901,14 +4171,13 @@ def _handle_pingresp(self) -> MQTTErrorCode: def _handle_connack(self) -> MQTTErrorCode: if self._protocol == MQTTv5: - if self._in_packet['remaining_length'] < 2: + if self._in_packet["remaining_length"] < 2: return MQTTErrorCode.MQTT_ERR_PROTOCOL - elif self._in_packet['remaining_length'] != 2: + elif self._in_packet["remaining_length"] != 2: return MQTTErrorCode.MQTT_ERR_PROTOCOL if self._protocol == MQTTv5: - (flags, result) = struct.unpack( - "!BB", self._in_packet['packet'][:2]) + (flags, result) = struct.unpack("!BB", self._in_packet["packet"][:2]) if result == 1: # This is probably a failure from a broker that doesn't support # MQTT v5. @@ -3917,9 +4186,9 @@ def _handle_connack(self) -> MQTTErrorCode: else: reason = ReasonCode(CONNACK >> 4, identifier=result) properties = Properties(CONNACK >> 4) - properties.unpack(self._in_packet['packet'][2:]) + properties.unpack(self._in_packet["packet"][2:]) else: - (flags, result) = struct.unpack("!BB", self._in_packet['packet']) + (flags, result) = struct.unpack("!BB", self._in_packet["packet"]) reason = convert_connack_rc_to_reason_code(result) properties = None if self._protocol == MQTTv311: @@ -3929,19 +4198,22 @@ def _handle_connack(self) -> MQTTErrorCode: self._easy_log( MQTT_LOG_DEBUG, "Received CONNACK (%s, %s), attempting downgrade to MQTT v3.1.", - flags, result + flags, + result, ) # Downgrade to MQTT v3.1 self._protocol = MQTTv31 return self.reconnect() - elif (result == CONNACK_REFUSED_IDENTIFIER_REJECTED - and self._client_id == b''): + elif ( + result == CONNACK_REFUSED_IDENTIFIER_REJECTED and self._client_id == b"" + ): if not self._reconnect_on_failure: return MQTT_ERR_PROTOCOL self._easy_log( MQTT_LOG_DEBUG, "Received CONNACK (%s, %s), attempting to use non-empty CID", - flags, result, + flags, + result, ) self._client_id = _base62(uuid.uuid4().int, padding=22).encode("utf8") return self.reconnect() @@ -3952,10 +4224,14 @@ def _handle_connack(self) -> MQTTErrorCode: if self._protocol == MQTTv5: self._easy_log( - MQTT_LOG_DEBUG, "Received CONNACK (%s, %s) properties=%s", flags, reason, properties) + MQTT_LOG_DEBUG, + "Received CONNACK (%s, %s) properties=%s", + flags, + reason, + properties, + ) else: - self._easy_log( - MQTT_LOG_DEBUG, "Received CONNACK (%s, %s)", flags, result) + self._easy_log(MQTT_LOG_DEBUG, "Received CONNACK (%s, %s)", flags, result) # it won't be the first successful connect any more self._mqttv5_first_connect = False @@ -3965,25 +4241,25 @@ def _handle_connack(self) -> MQTTErrorCode: if on_connect: flags_dict = {} - flags_dict['session present'] = flags & 0x01 + flags_dict["session present"] = flags & 0x01 with self._in_callback_mutex: try: if self._callback_api_version == CallbackAPIVersion.VERSION1: if self._protocol == MQTTv5: on_connect = cast(CallbackOnConnect_v1_mqtt5, on_connect) - on_connect(self, self._userdata, - flags_dict, reason, properties) + on_connect( + self, self._userdata, flags_dict, reason, properties + ) else: on_connect = cast(CallbackOnConnect_v1_mqtt3, on_connect) - on_connect( - self, self._userdata, flags_dict, result) + on_connect(self, self._userdata, flags_dict, result) elif self._callback_api_version == CallbackAPIVersion.VERSION2: on_connect = cast(CallbackOnConnect_v2, on_connect) connect_flags = ConnectFlags( - session_present=flags_dict['session present'] > 0 + session_present=flags_dict["session present"] > 0 ) if properties is None: @@ -4000,7 +4276,8 @@ def _handle_connack(self) -> MQTTErrorCode: raise RuntimeError("Unsupported callback API version") except Exception as err: self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_connect: %s', err) + MQTT_LOG_ERR, "Caught exception in on_connect: %s", err + ) if not self.suppress_exceptions: raise @@ -4017,12 +4294,12 @@ def _handle_connack(self) -> MQTTErrorCode: with self._in_callback_mutex: # Don't call loop_write after _send_publish() rc = self._send_publish( m.mid, - m.topic.encode('utf-8'), + m.topic.encode("utf-8"), m.payload, m.qos, m.retain, m.dup, - properties=m.properties + properties=m.properties, ) if rc != MQTTErrorCode.MQTT_ERR_SUCCESS: return rc @@ -4033,12 +4310,12 @@ def _handle_connack(self) -> MQTTErrorCode: with self._in_callback_mutex: # Don't call loop_write after _send_publish() rc = self._send_publish( m.mid, - m.topic.encode('utf-8'), + m.topic.encode("utf-8"), m.payload, m.qos, m.retain, m.dup, - properties=m.properties + properties=m.properties, ) if rc != MQTTErrorCode.MQTT_ERR_SUCCESS: return rc @@ -4049,12 +4326,12 @@ def _handle_connack(self) -> MQTTErrorCode: with self._in_callback_mutex: # Don't call loop_write after _send_publish() rc = self._send_publish( m.mid, - m.topic.encode('utf-8'), + m.topic.encode("utf-8"), m.payload, m.qos, m.retain, m.dup, - properties=m.properties + properties=m.properties, ) if rc != MQTTErrorCode.MQTT_ERR_SUCCESS: return rc @@ -4076,17 +4353,15 @@ def _handle_connack(self) -> MQTTErrorCode: def _handle_disconnect(self) -> None: packet_type = DISCONNECT >> 4 reasonCode = properties = None - if self._in_packet['remaining_length'] > 2: + if self._in_packet["remaining_length"] > 2: reasonCode = ReasonCode(packet_type) - reasonCode.unpack(self._in_packet['packet']) - if self._in_packet['remaining_length'] > 3: + reasonCode.unpack(self._in_packet["packet"]) + if self._in_packet["remaining_length"] > 3: properties = Properties(packet_type) - props, props_len = properties.unpack( - self._in_packet['packet'][1:]) - self._easy_log(MQTT_LOG_DEBUG, "Received DISCONNECT %s %s", - reasonCode, - properties - ) + props, props_len = properties.unpack(self._in_packet["packet"][1:]) + self._easy_log( + MQTT_LOG_DEBUG, "Received DISCONNECT %s %s", reasonCode, properties + ) self._sock_close() self._do_on_disconnect( @@ -4099,12 +4374,14 @@ def _handle_disconnect(self) -> None: def _handle_suback(self) -> None: self._easy_log(MQTT_LOG_DEBUG, "Received SUBACK") pack_format = f"!H{len(self._in_packet['packet']) - 2}s" - (mid, packet) = struct.unpack(pack_format, self._in_packet['packet']) + (mid, packet) = struct.unpack(pack_format, self._in_packet["packet"]) if self._protocol == MQTTv5: properties = Properties(SUBACK >> 4) props, props_len = properties.unpack(packet) - reasoncodes = [ReasonCode(SUBACK >> 4, identifier=c) for c in packet[props_len:]] + reasoncodes = [ + ReasonCode(SUBACK >> 4, identifier=c) for c in packet[props_len:] + ] else: pack_format = f"!{'B' * len(packet)}" granted_qos = struct.unpack(pack_format, packet) @@ -4119,15 +4396,19 @@ def _handle_suback(self) -> None: try: if self._callback_api_version == CallbackAPIVersion.VERSION1: if self._protocol == MQTTv5: - on_subscribe = cast(CallbackOnSubscribe_v1_mqtt5, on_subscribe) + on_subscribe = cast( + CallbackOnSubscribe_v1_mqtt5, on_subscribe + ) on_subscribe( - self, self._userdata, mid, reasoncodes, properties) + self, self._userdata, mid, reasoncodes, properties + ) else: - on_subscribe = cast(CallbackOnSubscribe_v1_mqtt3, on_subscribe) + on_subscribe = cast( + CallbackOnSubscribe_v1_mqtt3, on_subscribe + ) - on_subscribe( - self, self._userdata, mid, granted_qos) + on_subscribe(self, self._userdata, mid, granted_qos) elif self._callback_api_version == CallbackAPIVersion.VERSION2: on_subscribe = cast(CallbackOnSubscribe_v2, on_subscribe) @@ -4142,19 +4423,20 @@ def _handle_suback(self) -> None: raise RuntimeError("Unsupported callback API version") except Exception as err: self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_subscribe: %s', err) + MQTT_LOG_ERR, "Caught exception in on_subscribe: %s", err + ) if not self.suppress_exceptions: raise def _handle_publish(self) -> MQTTErrorCode: - header = self._in_packet['command'] + header = self._in_packet["command"] message = MQTTMessage() message.dup = ((header & 0x08) >> 3) != 0 message.qos = (header & 0x06) >> 1 message.retain = (header & 0x01) != 0 pack_format = f"!H{len(self._in_packet['packet']) - 2}s" - (slen, packet) = struct.unpack(pack_format, self._in_packet['packet']) + (slen, packet) = struct.unpack(pack_format, self._in_packet["packet"]) pack_format = f"!{slen}s{len(packet) - slen}s" (topic, packet) = struct.unpack(pack_format, packet) @@ -4166,7 +4448,7 @@ def _handle_publish(self) -> MQTTErrorCode: # representation of the topic for logging. When the user attempts to # access message.topic in the callback, an exception will be raised. try: - print_topic = topic.decode('utf-8') + print_topic = topic.decode("utf-8") except UnicodeDecodeError: print_topic = f"TOPIC WITH INVALID UTF-8: {topic!r}" @@ -4187,15 +4469,24 @@ def _handle_publish(self) -> MQTTErrorCode: self._easy_log( MQTT_LOG_DEBUG, "Received PUBLISH (d%d, q%d, r%d, m%d), '%s', properties=%s, ... (%d bytes)", - message.dup, message.qos, message.retain, message.mid, - print_topic, message.properties, len(message.payload) + message.dup, + message.qos, + message.retain, + message.mid, + print_topic, + message.properties, + len(message.payload), ) else: self._easy_log( MQTT_LOG_DEBUG, "Received PUBLISH (d%d, q%d, r%d, m%d), '%s', ... (%d bytes)", - message.dup, message.qos, message.retain, message.mid, - print_topic, len(message.payload) + message.dup, + message.qos, + message.retain, + message.mid, + print_topic, + len(message.payload), ) message.timestamp = time_func() @@ -4222,10 +4513,10 @@ def _handle_publish(self) -> MQTTErrorCode: def ack(self, mid: int, qos: int) -> MQTTErrorCode: """ - send an acknowledgement for a given message id (stored in :py:attr:`message.mid `). - only useful in QoS>=1 and ``manual_ack=True`` (option of `Client`) + send an acknowledgement for a given message id (stored in :py:attr:`message.mid `). + only useful in QoS>=1 and ``manual_ack=True`` (option of `Client`) """ - if self._manual_ack : + if self._manual_ack: if qos == 1: return self._send_puback(mid) elif qos == 2: @@ -4235,29 +4526,27 @@ def ack(self, mid: int, qos: int) -> MQTTErrorCode: def manual_ack_set(self, on: bool) -> None: """ - The paho library normally acknowledges messages as soon as they are delivered to the caller. - If manual_ack is turned on, then the caller MUST manually acknowledge every message once - application processing is complete using `ack()` + The paho library normally acknowledges messages as soon as they are delivered to the caller. + If manual_ack is turned on, then the caller MUST manually acknowledge every message once + application processing is complete using `ack()` """ self._manual_ack = on - def _handle_pubrel(self) -> MQTTErrorCode: if self._protocol == MQTTv5: - if self._in_packet['remaining_length'] < 2: + if self._in_packet["remaining_length"] < 2: return MQTTErrorCode.MQTT_ERR_PROTOCOL - elif self._in_packet['remaining_length'] != 2: + elif self._in_packet["remaining_length"] != 2: return MQTTErrorCode.MQTT_ERR_PROTOCOL - mid, = struct.unpack("!H", self._in_packet['packet'][:2]) + (mid,) = struct.unpack("!H", self._in_packet["packet"][:2]) if self._protocol == MQTTv5: - if self._in_packet['remaining_length'] > 2: + if self._in_packet["remaining_length"] > 2: reasonCode = ReasonCode(PUBREL >> 4) - reasonCode.unpack(self._in_packet['packet'][2:]) - if self._in_packet['remaining_length'] > 3: + reasonCode.unpack(self._in_packet["packet"][2:]) + if self._in_packet["remaining_length"] > 3: properties = Properties(PUBREL >> 4) - props, props_len = properties.unpack( - self._in_packet['packet'][3:]) + props, props_len = properties.unpack(self._in_packet["packet"][3:]) self._easy_log(MQTT_LOG_DEBUG, "Received PUBREL (Mid: %d)", mid) with self._in_message_mutex: @@ -4296,7 +4585,7 @@ def _update_inflight(self) -> MQTTErrorCode: m.state = mqtt_ms_wait_for_pubrec rc = self._send_publish( m.mid, - m.topic.encode('utf-8'), + m.topic.encode("utf-8"), m.payload, m.qos, m.retain, @@ -4311,20 +4600,19 @@ def _update_inflight(self) -> MQTTErrorCode: def _handle_pubrec(self) -> MQTTErrorCode: if self._protocol == MQTTv5: - if self._in_packet['remaining_length'] < 2: + if self._in_packet["remaining_length"] < 2: return MQTTErrorCode.MQTT_ERR_PROTOCOL - elif self._in_packet['remaining_length'] != 2: + elif self._in_packet["remaining_length"] != 2: return MQTTErrorCode.MQTT_ERR_PROTOCOL - mid, = struct.unpack("!H", self._in_packet['packet'][:2]) + (mid,) = struct.unpack("!H", self._in_packet["packet"][:2]) if self._protocol == MQTTv5: - if self._in_packet['remaining_length'] > 2: + if self._in_packet["remaining_length"] > 2: reasonCode = ReasonCode(PUBREC >> 4) - reasonCode.unpack(self._in_packet['packet'][2:]) - if self._in_packet['remaining_length'] > 3: + reasonCode.unpack(self._in_packet["packet"][2:]) + if self._in_packet["remaining_length"] > 3: properties = Properties(PUBREC >> 4) - props, props_len = properties.unpack( - self._in_packet['packet'][3:]) + props, props_len = properties.unpack(self._in_packet["packet"][3:]) self._easy_log(MQTT_LOG_DEBUG, "Received PUBREC (Mid: %d)", mid) with self._out_message_mutex: @@ -4338,19 +4626,18 @@ def _handle_pubrec(self) -> MQTTErrorCode: def _handle_unsuback(self) -> MQTTErrorCode: if self._protocol == MQTTv5: - if self._in_packet['remaining_length'] < 4: + if self._in_packet["remaining_length"] < 4: return MQTTErrorCode.MQTT_ERR_PROTOCOL - elif self._in_packet['remaining_length'] != 2: + elif self._in_packet["remaining_length"] != 2: return MQTTErrorCode.MQTT_ERR_PROTOCOL - mid, = struct.unpack("!H", self._in_packet['packet'][:2]) + (mid,) = struct.unpack("!H", self._in_packet["packet"][:2]) if self._protocol == MQTTv5: - packet = self._in_packet['packet'][2:] + packet = self._in_packet["packet"][2:] properties = Properties(UNSUBACK >> 4) props, props_len = properties.unpack(packet) reasoncodes_list = [ - ReasonCode(UNSUBACK >> 4, identifier=c) - for c in packet[props_len:] + ReasonCode(UNSUBACK >> 4, identifier=c) for c in packet[props_len:] ] else: reasoncodes_list = [] @@ -4365,16 +4652,23 @@ def _handle_unsuback(self) -> MQTTErrorCode: try: if self._callback_api_version == CallbackAPIVersion.VERSION1: if self._protocol == MQTTv5: - on_unsubscribe = cast(CallbackOnUnsubscribe_v1_mqtt5, on_unsubscribe) + on_unsubscribe = cast( + CallbackOnUnsubscribe_v1_mqtt5, on_unsubscribe + ) - reasoncodes: ReasonCode | list[ReasonCode] = reasoncodes_list + reasoncodes: ReasonCode | list[ReasonCode] = ( + reasoncodes_list + ) if len(reasoncodes_list) == 1: reasoncodes = reasoncodes_list[0] on_unsubscribe( - self, self._userdata, mid, properties, reasoncodes) + self, self._userdata, mid, properties, reasoncodes + ) else: - on_unsubscribe = cast(CallbackOnUnsubscribe_v1_mqtt3, on_unsubscribe) + on_unsubscribe = cast( + CallbackOnUnsubscribe_v1_mqtt3, on_unsubscribe + ) on_unsubscribe(self, self._userdata, mid) elif self._callback_api_version == CallbackAPIVersion.VERSION2: @@ -4394,7 +4688,8 @@ def _handle_unsuback(self) -> MQTTErrorCode: raise RuntimeError("Unsupported callback API version") except Exception as err: self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_unsubscribe: %s', err) + MQTT_LOG_ERR, "Caught exception in on_unsubscribe: %s", err + ) if not self.suppress_exceptions: raise @@ -4415,14 +4710,18 @@ def _do_on_disconnect( try: if self._callback_api_version == CallbackAPIVersion.VERSION1: if self._protocol == MQTTv5: - on_disconnect = cast(CallbackOnDisconnect_v1_mqtt5, on_disconnect) + on_disconnect = cast( + CallbackOnDisconnect_v1_mqtt5, on_disconnect + ) if packet_from_broker: on_disconnect(self, self._userdata, reason, properties) else: on_disconnect(self, self._userdata, v1_rc, None) else: - on_disconnect = cast(CallbackOnDisconnect_v1_mqtt3, on_disconnect) + on_disconnect = cast( + CallbackOnDisconnect_v1_mqtt3, on_disconnect + ) on_disconnect(self, self._userdata, v1_rc) elif self._callback_api_version == CallbackAPIVersion.VERSION2: @@ -4449,11 +4748,14 @@ def _do_on_disconnect( raise RuntimeError("Unsupported callback API version") except Exception as err: self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_disconnect: %s', err) + MQTT_LOG_ERR, "Caught exception in on_disconnect: %s", err + ) if not self.suppress_exceptions: raise - def _do_on_publish(self, mid: int, reason_code: ReasonCode, properties: Properties) -> MQTTErrorCode: + def _do_on_publish( + self, mid: int, reason_code: ReasonCode, properties: Properties + ) -> MQTTErrorCode: with self._callback_mutex: on_publish = self.on_publish @@ -4478,7 +4780,8 @@ def _do_on_publish(self, mid: int, reason_code: ReasonCode, properties: Properti raise RuntimeError("Unsupported callback API version") except Exception as err: self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_publish: %s', err) + MQTT_LOG_ERR, "Caught exception in on_publish: %s", err + ) if not self.suppress_exceptions: raise @@ -4493,25 +4796,24 @@ def _do_on_publish(self, mid: int, reason_code: ReasonCode, properties: Properti return MQTTErrorCode.MQTT_ERR_SUCCESS def _handle_pubackcomp( - self, cmd: Literal['PUBACK'] | Literal['PUBCOMP'] + self, cmd: Literal["PUBACK"] | Literal["PUBCOMP"] ) -> MQTTErrorCode: if self._protocol == MQTTv5: - if self._in_packet['remaining_length'] < 2: + if self._in_packet["remaining_length"] < 2: return MQTTErrorCode.MQTT_ERR_PROTOCOL - elif self._in_packet['remaining_length'] != 2: + elif self._in_packet["remaining_length"] != 2: return MQTTErrorCode.MQTT_ERR_PROTOCOL packet_type_enum = PUBACK if cmd == "PUBACK" else PUBCOMP packet_type = packet_type_enum.value >> 4 - mid, = struct.unpack("!H", self._in_packet['packet'][:2]) + (mid,) = struct.unpack("!H", self._in_packet["packet"][:2]) reasonCode = ReasonCode(packet_type) properties = Properties(packet_type) if self._protocol == MQTTv5: - if self._in_packet['remaining_length'] > 2: - reasonCode.unpack(self._in_packet['packet'][2:]) - if self._in_packet['remaining_length'] > 3: - props, props_len = properties.unpack( - self._in_packet['packet'][3:]) + if self._in_packet["remaining_length"] > 2: + reasonCode.unpack(self._in_packet["packet"][2:]) + if self._in_packet["remaining_length"] > 3: + props, props_len = properties.unpack(self._in_packet["packet"][3:]) self._easy_log(MQTT_LOG_DEBUG, "Received %s (Mid: %d)", cmd, mid) with self._out_message_mutex: @@ -4532,7 +4834,9 @@ def _handle_on_message(self, message: MQTTMessage) -> None: on_message_callbacks = [] with self._callback_mutex: if topic is not None: - on_message_callbacks = list(self._on_message_filtered.iter_match(message.topic)) + on_message_callbacks = list( + self._on_message_filtered.iter_match(message.topic) + ) if len(on_message_callbacks) == 0: on_message = self.on_message @@ -4546,9 +4850,9 @@ def _handle_on_message(self, message: MQTTMessage) -> None: except Exception as err: self._easy_log( MQTT_LOG_ERR, - 'Caught exception in user defined callback function %s: %s', + "Caught exception in user defined callback function %s: %s", callback.__name__, - err + err, ) if not self.suppress_exceptions: raise @@ -4559,11 +4863,11 @@ def _handle_on_message(self, message: MQTTMessage) -> None: on_message(self, self._userdata, message) except Exception as err: self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_message: %s', err) + MQTT_LOG_ERR, "Caught exception in on_message: %s", err + ) if not self.suppress_exceptions: raise - def _handle_on_connect_fail(self) -> None: with self._callback_mutex: on_connect_fail = self.on_connect_fail @@ -4574,7 +4878,8 @@ def _handle_on_connect_fail(self) -> None: on_connect_fail(self, self._userdata) except Exception as err: self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_connect_fail: %s', err) + MQTT_LOG_ERR, "Caught exception in on_connect_fail: %s", err + ) def _thread_main(self) -> None: try: @@ -4597,9 +4902,15 @@ def _reconnect_wait(self) -> None: target_time = now + self._reconnect_delay remaining = target_time - now - while (self._state not in (_ConnectionState.MQTT_CS_DISCONNECTING, _ConnectionState.MQTT_CS_DISCONNECTED) - and not self._thread_terminate - and remaining > 0): + while ( + self._state + not in ( + _ConnectionState.MQTT_CS_DISCONNECTING, + _ConnectionState.MQTT_CS_DISCONNECTED, + ) + and not self._thread_terminate + and remaining > 0 + ): time.sleep(min(remaining, 1)) remaining = target_time - time_func() @@ -4607,8 +4918,11 @@ def _reconnect_wait(self) -> None: @staticmethod def _proxy_is_valid(p) -> bool: # type: ignore[no-untyped-def] def check(t, a) -> bool: # type: ignore[no-untyped-def] - return (socks is not None and - t in {socks.HTTP, socks.SOCKS4, socks.SOCKS5} and a) + return ( + socks is not None + and t in {socks.HTTP, socks.SOCKS4, socks.SOCKS5} + and a + ) if isinstance(p, dict): return check(p.get("proxy_type"), p.get("proxy_addr")) @@ -4628,8 +4942,10 @@ def _get_proxy(self) -> dict[str, Any] | None: # Next, check for an mqtt_proxy environment variable as long as the host # we're trying to connect to isn't listed under the no_proxy environment # variable (matches built-in module urllib's behavior) - if not (hasattr(urllib.request, "proxy_bypass") and - urllib.request.proxy_bypass(self._host)): + if not ( + hasattr(urllib.request, "proxy_bypass") + and urllib.request.proxy_bypass(self._host) + ): env_proxies = urllib.request.getproxies() if "mqtt" in env_proxies: parts = urllib.parse.urlparse(env_proxies["mqtt"]) @@ -4637,14 +4953,14 @@ def _get_proxy(self) -> dict[str, Any] | None: proxy = { "proxy_type": socks.HTTP, "proxy_addr": parts.hostname, - "proxy_port": parts.port + "proxy_port": parts.port, } return proxy elif parts.scheme == "socks": proxy = { "proxy_type": socks.SOCKS5, "proxy_addr": parts.hostname, - "proxy_port": parts.port + "proxy_port": parts.port, } return proxy @@ -4652,8 +4968,14 @@ def _get_proxy(self) -> dict[str, Any] | None: # a default proxy socks_default = socks.get_default_proxy() if self._proxy_is_valid(socks_default): - proxy_keys = ("proxy_type", "proxy_addr", "proxy_port", - "proxy_rdns", "proxy_username", "proxy_password") + proxy_keys = ( + "proxy_type", + "proxy_addr", + "proxy_port", + "proxy_rdns", + "proxy_username", + "proxy_password", + ) return dict(zip(proxy_keys, socks_default)) # If we didn't find a proxy through any of the above methods, return @@ -4693,16 +5015,20 @@ def _create_socket_connection(self) -> _socket.socket: source = (self._bind_address, self._bind_port) if proxy: - return socks.create_connection(addr, timeout=self._connect_timeout, source_address=source, **proxy) + return socks.create_connection( + addr, timeout=self._connect_timeout, source_address=source, **proxy + ) else: - return socket.create_connection(addr, timeout=self._connect_timeout, source_address=source) - + return socket.create_connection( + addr, timeout=self._connect_timeout, source_address=source + ) + def _ssl_wrap_socket(self, tcp_sock: _socket) -> _socket.socket: if self._ssl_context is None: raise ValueError( "Impossible condition. _ssl_context should never be None if _ssl is True" ) - + verify_host = not self._tls_insecure try: if isinstance(self._ssl_context, ssl.SSLContext): @@ -4717,7 +5043,7 @@ def _ssl_wrap_socket(self, tcp_sock: _socket) -> _socket.socket: conn = SSL.Connection(self._ssl_context, tcp_sock) conn.set_connect_state() if self._host: - conn.set_tlsext_host_name(self._host.encode('utf-8')) + conn.set_tlsext_host_name(self._host.encode("utf-8")) ssl_sock = conn else: raise ValueError("Unsupported SSL context type") @@ -4748,25 +5074,26 @@ def do_handshake_with_retries(ssl_sock, retries=35, delay=0.1): if HAS_OPENSSL and isinstance(ssl_sock, SSL.Connection): do_handshake_with_retries(ssl_sock) if verify_host: - if getattr(self._ssl_context, 'check_hostname', False): + if getattr(self._ssl_context, "check_hostname", False): verify_host = False _openssl_match_hostname(ssl_sock.get_peer_certificate(), self._host) else: ssl_sock.do_handshake() if verify_host: - if getattr(self._ssl_context, 'check_hostname', False): + if getattr(self._ssl_context, "check_hostname", False): verify_host = False ssl.match_hostname(ssl_sock.getpeercert(), self._host) return ssl_sock + class _WebsocketWrapper: OPCODE_CONTINUATION = 0x0 OPCODE_TEXT = 0x1 OPCODE_BINARY = 0x2 OPCODE_CONNCLOSE = 0x8 OPCODE_PING = 0x9 - OPCODE_PONG = 0xa + OPCODE_PONG = 0xA def __init__( self, @@ -4832,11 +5159,13 @@ def _do_handshake(self, extra_headers: WebSocketHeaders | None) -> None: elif callable(extra_headers): websocket_headers = extra_headers(websocket_headers) - header = "\r\n".join([ - f"GET {self._path} HTTP/1.1", - "\r\n".join(f"{i}: {j}" for i, j in websocket_headers.items()), - "\r\n", - ]).encode("utf8") + header = "\r\n".join( + [ + f"GET {self._path} HTTP/1.1", + "\r\n".join(f"{i}: {j}" for i, j in websocket_headers.items()), + "\r\n", + ] + ).encode("utf8") self._socket.send(header) @@ -4856,29 +5185,38 @@ def _do_handshake(self, extra_headers: WebSocketHeaders | None) -> None: if byte == b"\n": if len(self._readbuffer) > 2: # check upgrade - if b"connection" in str(self._readbuffer).lower().encode('utf-8'): - if b"upgrade" not in str(self._readbuffer).lower().encode('utf-8'): + if b"connection" in str(self._readbuffer).lower().encode("utf-8"): + if b"upgrade" not in str(self._readbuffer).lower().encode( + "utf-8" + ): raise WebsocketConnectionError( - "WebSocket handshake error, connection not upgraded") + "WebSocket handshake error, connection not upgraded" + ) else: has_upgrade = True # check key hash - if b"sec-websocket-accept" in str(self._readbuffer).lower().encode('utf-8'): + if b"sec-websocket-accept" in str(self._readbuffer).lower().encode( + "utf-8" + ): GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" - server_hash_str = self._readbuffer.decode( - 'utf-8').split(": ", 1)[1] - server_hash = server_hash_str.strip().encode('utf-8') + server_hash_str = self._readbuffer.decode("utf-8").split( + ": ", 1 + )[1] + server_hash = server_hash_str.strip().encode("utf-8") - client_hash_key = sec_websocket_key.decode('utf-8') + GUID + client_hash_key = sec_websocket_key.decode("utf-8") + GUID # Use of SHA-1 is OK here; it's according to the Websocket spec. - client_hash_digest = hashlib.sha1(client_hash_key.encode('utf-8')) # noqa: S324 + client_hash_digest = hashlib.sha1( + client_hash_key.encode("utf-8") + ) # noqa: S324 client_hash = base64.b64encode(client_hash_digest.digest()) if server_hash != client_hash: raise WebsocketConnectionError( - "WebSocket handshake error, invalid secret key") + "WebSocket handshake error, invalid secret key" + ) else: has_secret = True else: @@ -4948,7 +5286,7 @@ def _buffered_read(self, length: int) -> bytearray: raise BlockingIOError self._readbuffer_head += length - return self._readbuffer[self._readbuffer_head - length:self._readbuffer_head] + return self._readbuffer[self._readbuffer_head - length : self._readbuffer_head] def _recv_impl(self, length: int) -> bytes: @@ -4965,22 +5303,22 @@ def _recv_impl(self, length: int) -> bytes: header1 = self._buffered_read(1) header2 = self._buffered_read(1) - opcode = (header1[0] & 0x0f) + opcode = header1[0] & 0x0F maskbit = (header2[0] & 0x80) == 0x80 - lengthbits = (header2[0] & 0x7f) + lengthbits = header2[0] & 0x7F payload_length = lengthbits mask_key = None # read length - if lengthbits == 0x7e: + if lengthbits == 0x7E: value = self._buffered_read(2) - payload_length, = struct.unpack("!H", value) + (payload_length,) = struct.unpack("!H", value) - elif lengthbits == 0x7f: + elif lengthbits == 0x7F: value = self._buffered_read(8) - payload_length, = struct.unpack("!Q", value) + (payload_length,) = struct.unpack("!Q", value) # read mask if maskbit: @@ -5013,33 +5351,36 @@ def _recv_impl(self, length: int) -> bytes: # respond to non-binary opcodes, their arrival is not guaranteed because of non-blocking sockets if opcode == _WebsocketWrapper.OPCODE_CONNCLOSE: frame = self._create_frame( - _WebsocketWrapper.OPCODE_CONNCLOSE, payload, 0) + _WebsocketWrapper.OPCODE_CONNCLOSE, payload, 0 + ) self._socket.send(frame) if opcode == _WebsocketWrapper.OPCODE_PING: frame = self._create_frame( - _WebsocketWrapper.OPCODE_PONG, payload, 0) + _WebsocketWrapper.OPCODE_PONG, payload, 0 + ) self._socket.send(frame) # This isn't *proper* handling of continuation frames, but given # that we only support binary frames, it is *probably* good enough. - if (opcode == _WebsocketWrapper.OPCODE_BINARY or opcode == _WebsocketWrapper.OPCODE_CONTINUATION) \ - and payload_length > 0: + if ( + opcode == _WebsocketWrapper.OPCODE_BINARY + or opcode == _WebsocketWrapper.OPCODE_CONTINUATION + ) and payload_length > 0: return result else: raise BlockingIOError except ConnectionError: self.connected = False - return b'' + return b"" def _send_impl(self, data: bytes) -> int: # if previous frame was sent successfully if len(self._sendbuffer) == 0: # create websocket frame - frame = self._create_frame( - _WebsocketWrapper.OPCODE_BINARY, bytearray(data)) + frame = self._create_frame(_WebsocketWrapper.OPCODE_BINARY, bytearray(data)) self._sendbuffer.extend(frame) self._requested_size = len(data)