From cef5be4927681a933052b51e6617b5042c9d9154 Mon Sep 17 00:00:00 2001 From: Pierre Fersing Date: Wed, 3 Jan 2024 22:13:49 +0100 Subject: [PATCH] Run pre-commit to apply format --- src/paho/mqtt/client.py | 805 ++++++++---------- src/paho/mqtt/matcher.py | 28 +- src/paho/mqtt/packettypes.py | 26 +- src/paho/mqtt/properties.py | 155 ++-- src/paho/mqtt/publish.py | 53 +- src/paho/mqtt/reasoncodes.py | 105 ++- src/paho/mqtt/subscribe.py | 98 ++- src/paho/mqtt/subscribeoptions.py | 31 +- tests/debug_helpers.py | 102 +-- tests/lib/clients/01-asyncio.py | 7 +- tests/lib/clients/01-unpwd-unicode-set.py | 1 - tests/lib/clients/03-publish-fill-inflight.py | 2 + tests/lib/conftest.py | 13 +- tests/lib/test_01_reconnect_on_failure.py | 3 +- tests/lib/test_01_unpwd_empty_password_set.py | 3 +- tests/lib/test_01_unpwd_empty_set.py | 3 +- tests/lib/test_01_unpwd_set.py | 3 +- tests/lib/test_01_will_set.py | 4 +- tests/lib/test_01_will_unpwd_set.py | 8 +- tests/lib/test_03_publish_b2c_qos1.py | 3 +- .../test_03_publish_c2b_qos1_disconnect.py | 10 +- .../test_03_publish_c2b_qos2_disconnect.py | 10 +- tests/lib/test_03_publish_fill_inflight.py | 18 +- tests/lib/test_03_publish_helper_qos0.py | 7 +- tests/lib/test_03_publish_helper_qos0_v5.py | 8 +- .../test_03_publish_helper_qos1_disconnect.py | 12 +- tests/lib/test_04_retain_qos0.py | 3 +- tests/mqtt5_props.py | 19 +- tests/paho_test.py | 153 ++-- tests/test_client.py | 67 +- tests/test_matcher.py | 49 +- tests/test_mqttv5.py | 355 +++----- tests/test_websocket_integration.py | 168 ++-- tests/test_websockets.py | 25 +- tests/testsupport/broker.py | 6 +- 35 files changed, 1104 insertions(+), 1259 deletions(-) diff --git a/src/paho/mqtt/client.py b/src/paho/mqtt/client.py index f6336beb..09b273e1 100644 --- a/src/paho/mqtt/client.py +++ b/src/paho/mqtt/client.py @@ -65,7 +65,7 @@ HAVE_DNS = False -if platform.system() == 'Windows': +if platform.system() == "Windows": EAGAIN = errno.WSAEWOULDBLOCK else: EAGAIN = errno.EAGAIN @@ -231,7 +231,7 @@ def base62(num, base=string.digits + string.ascii_letters, padding=1): num, rest = divmod(num, 62) digits.append(base[rest]) digits.extend(base[0] for _ in range(len(digits), padding)) - return ''.join(reversed(digits)) + return "".join(reversed(digits)) def topic_matches_sub(sub, topic): @@ -253,15 +253,13 @@ def topic_matches_sub(sub, topic): def _socketpair_compat(): """TCP/IP socketpair including Windows support""" - listensock = socket.socket( - socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_IP) + listensock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_IP) listensock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) listensock.bind(("127.0.0.1", 0)) listensock.listen(1) iface, port = listensock.getsockname() - sock1 = socket.socket( - socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_IP) + sock1 = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_IP) sock1.setblocking(0) try: sock1.connect(("127.0.0.1", port)) @@ -279,7 +277,7 @@ class MQTTMessageInfo: message has been published, and/or wait until it is published. """ - __slots__ = 'mid', '_published', '_condition', 'rc', '_iterpos' + __slots__ = "mid", "_published", "_condition", "rc", "_iterpos" def __init__(self, mid): self.mid = mid @@ -334,14 +332,15 @@ def wait_for_publish(self, timeout=None): reason. """ if self.rc == MQTT_ERR_QUEUE_SIZE: - raise ValueError('Message is not queued due to ERR_QUEUE_SIZE') + raise ValueError("Message is not queued due to ERR_QUEUE_SIZE") elif self.rc == MQTT_ERR_AGAIN: pass elif self.rc > 0: - raise RuntimeError(f'Message publish failed: {error_string(self.rc)}') + raise RuntimeError(f"Message publish failed: {error_string(self.rc)}") timeout_time = None if timeout is None else time_func() + timeout - timeout_tenth = None if timeout is None else timeout / 10. + timeout_tenth = None if timeout is None else timeout / 10.0 + def timed_out(): return False if timeout is None else time_func() > timeout_time @@ -353,18 +352,18 @@ def is_published(self): """Returns True if the message associated with this object has been published, else returns False.""" if self.rc == MQTT_ERR_QUEUE_SIZE: - raise ValueError('Message is not queued due to ERR_QUEUE_SIZE') + raise ValueError("Message is not queued due to ERR_QUEUE_SIZE") elif self.rc == MQTT_ERR_AGAIN: pass elif self.rc > 0: - raise RuntimeError(f'Message publish failed: {error_string(self.rc)}') + raise RuntimeError(f"Message publish failed: {error_string(self.rc)}") with self._condition: return self._published class MQTTMessage: - """ This is a class that describes an incoming or outgoing message. It is + """This is a class that describes an incoming or outgoing message. It is passed to the on_message callback as the message parameter. Members: @@ -377,7 +376,7 @@ class MQTTMessage: properties: Properties class. In MQTT v5.0, the properties associated with the message. """ - __slots__ = 'timestamp', 'state', 'dup', 'mid', '_topic', 'payload', 'qos', 'retain', 'info', 'properties' + __slots__ = "timestamp", "state", "dup", "mid", "_topic", "payload", "qos", "retain", "info", "properties" def __init__(self, mid=0, topic=b""): self.timestamp = 0 @@ -402,7 +401,7 @@ def __ne__(self, other): @property def topic(self): - return self._topic.decode('utf-8') + return self._topic.decode("utf-8") @topic.setter def topic(self, value): @@ -469,9 +468,7 @@ def on_connect(client, userdata, flags, rc): on_socket_register_write, on_socket_unregister_write """ - def __init__(self, client_id="", clean_session=None, userdata=None, - protocol=MQTTv311, transport="tcp", reconnect_on_failure=True, - manual_ack=False ): + def __init__(self, client_id="", clean_session=None, userdata=None, protocol=MQTTv311, transport="tcp", reconnect_on_failure=True, manual_ack=False): """client_id is the unique client id string used when connecting to the broker. If client_id is zero length or None, then the behaviour is defined by which protocol version is in use. If using MQTT v3.1.1, then @@ -513,28 +510,29 @@ def __init__(self, client_id="", clean_session=None, userdata=None, """ - if transport.lower() not in ('websockets', 'tcp'): - raise ValueError( - f'transport must be "websockets" or "tcp", not {transport}') + if transport.lower() not in ("websockets", "tcp"): + raise ValueError(f'transport must be "websockets" or "tcp", not {transport}') self._manual_ack = manual_ack self._transport = transport.lower() self._protocol = protocol self._userdata = userdata self._sock = None - self._sockpairR, self._sockpairW = (None, None,) + self._sockpairR, self._sockpairW = ( + None, + None, + ) self._keepalive = 60 self._connect_timeout = 5.0 self._client_mode = MQTT_CLIENT if protocol == MQTTv5: if clean_session is not None: - raise ValueError('Clean session is not used for MQTT 5.0') + raise ValueError("Clean session is not used for MQTT 5.0") else: if clean_session is None: clean_session = True if not clean_session and (client_id == "" or client_id is None): - raise ValueError( - 'A client id must be provided if clean session is False.') + raise ValueError("A client id must be provided if clean session is False.") self._clean_session = clean_session # [MQTT-3.1.3-4] Client Id must be UTF-8 encoded string. @@ -546,7 +544,7 @@ def __init__(self, client_id="", clean_session=None, userdata=None, else: self._client_id = client_id if isinstance(self._client_id, str): - self._client_id = self._client_id.encode('utf-8') + self._client_id = self._client_id.encode("utf-8") self._username = None self._password = None @@ -558,7 +556,8 @@ def __init__(self, client_id="", clean_session=None, userdata=None, "remaining_length": 0, "packet": bytearray(b""), "to_process": 0, - "pos": 0} + "pos": 0, + } self._out_packet = collections.deque() self._last_msg_in = time_func() self._last_msg_out = time_func() @@ -620,7 +619,7 @@ def __init__(self, client_id="", clean_session=None, userdata=None, self._websocket_extra_headers = None # for clean_start == MQTT_CLEAN_START_FIRST_ONLY self._mqttv5_first_connect = True - self.suppress_exceptions = False # For callbacks + self.suppress_exceptions = False # For callbacks def __del__(self): self._reset_sockets() @@ -634,8 +633,7 @@ def _sock_recv(self, bufsize): self._call_socket_register_write() raise BlockingIOError() from err except AttributeError as err: - self._easy_log( - MQTT_LOG_DEBUG, "socket was None: %s", err) + self._easy_log(MQTT_LOG_DEBUG, "socket was None: %s", err) raise ConnectionError() from err def _sock_send(self, buf): @@ -681,7 +679,7 @@ def reinitialise(self, client_id="", clean_session=True, userdata=None): self.__init__(client_id, clean_session, userdata) def ws_set_options(self, path="/mqtt", headers=None): - """ Set the path and headers for a websocket connection + """Set the path and headers for a websocket connection path is a string starting with / which should be the endpoint of the mqtt connection on the remote server @@ -697,8 +695,7 @@ def ws_set_options(self, path="/mqtt", headers=None): if isinstance(headers, dict) or callable(headers): self._websocket_extra_headers = headers else: - raise ValueError( - "'headers' option to ws_set_options has to be either a dictionary or callable") + raise ValueError("'headers' option to ws_set_options has to be either a dictionary or callable") def tls_set_context(self, context=None): """Configure network encryption and authentication context. Enables SSL/TLS support. @@ -708,7 +705,7 @@ def tls_set_context(self, context=None): Must be called before connect() or connect_async().""" if self._ssl_context is not None: - raise ValueError('SSL/TLS has already been configured.') + raise ValueError("SSL/TLS has already been configured.") if context is None: context = ssl.create_default_context() @@ -717,7 +714,7 @@ def tls_set_context(self, context=None): self._ssl_context = context # Ensure _tls_insecure is consistent with check_hostname attribute - if hasattr(context, 'check_hostname'): + if hasattr(context, "check_hostname"): self._tls_insecure = not context.check_hostname def tls_set(self, ca_certs=None, certfile=None, keyfile=None, cert_reqs=None, tls_version=None, ciphers=None, keyfile_password=None): @@ -761,15 +758,14 @@ def tls_set(self, ca_certs=None, certfile=None, keyfile=None, cert_reqs=None, tl Must be called before connect() or connect_async().""" if ssl is None: - raise ValueError('This platform has no SSL/TLS.') + raise ValueError("This platform has no SSL/TLS.") - if not hasattr(ssl, 'SSLContext'): + if not hasattr(ssl, "SSLContext"): # Require Python version that has SSL context support in standard library - raise ValueError( - 'Python 2.7.9 and 3.2 are the minimum supported versions for TLS.') + raise ValueError("Python 2.7.9 and 3.2 are the minimum supported versions for TLS.") - if ca_certs is None and not hasattr(ssl.SSLContext, 'load_default_certs'): - raise ValueError('ca_certs must not be None.') + if ca_certs is None and not hasattr(ssl.SSLContext, "load_default_certs"): + raise ValueError("ca_certs must not be None.") # Create SSLContext object if tls_version is None: @@ -789,7 +785,7 @@ def tls_set(self, ca_certs=None, certfile=None, keyfile=None, cert_reqs=None, tl if certfile is not None: context.load_cert_chain(certfile, keyfile, keyfile_password) - if cert_reqs == ssl.CERT_NONE and hasattr(context, 'check_hostname'): + if cert_reqs == ssl.CERT_NONE and hasattr(context, "check_hostname"): context.check_hostname = False context.verify_mode = ssl.CERT_REQUIRED if cert_reqs is None else cert_reqs @@ -825,13 +821,12 @@ def tls_insecure_set(self, value): tls_set_context().""" if self._ssl_context is None: - raise ValueError( - 'Must configure SSL context before using tls_insecure_set.') + raise ValueError("Must configure SSL context before using tls_insecure_set.") self._tls_insecure = value # Ensure check_hostname is consistent with _tls_insecure attribute - if hasattr(self._ssl_context, 'check_hostname'): + if hasattr(self._ssl_context, "check_hostname"): # Rely on SSLContext to check host name # If verify_mode is CERT_NONE then the host name will never be checked self._ssl_context.check_hostname = not value @@ -862,7 +857,7 @@ def proxy_set(self, **proxy_args): self._proxy = proxy_args def enable_logger(self, logger=None): - """ Enables a logger to send log messages to """ + """Enables a logger to send log messages to""" if logger is None: if self._logger is not None: # Do not replace existing logger @@ -873,8 +868,7 @@ def enable_logger(self, logger=None): def disable_logger(self): self._logger = None - def connect(self, host, port=1883, keepalive=60, bind_address="", bind_port=0, - clean_start=MQTT_CLEAN_START_FIRST_ONLY, properties=None): + def connect(self, host, port=1883, keepalive=60, bind_address="", bind_port=0, clean_start=MQTT_CLEAN_START_FIRST_ONLY, properties=None): """Connect to a remote broker. This is a blocking call that establishes the underlying connection and transmits a CONNECT packet. @@ -901,12 +895,10 @@ def connect(self, host, port=1883, keepalive=60, bind_address="", bind_port=0, if properties: raise ValueError("Properties only apply to MQTT V5") - self.connect_async(host, port, keepalive, - bind_address, bind_port, clean_start, properties) + self.connect_async(host, port, keepalive, bind_address, bind_port, clean_start, properties) return self.reconnect() - def connect_srv(self, domain=None, keepalive=60, bind_address="", - clean_start=MQTT_CLEAN_START_FIRST_ONLY, properties=None): + def connect_srv(self, domain=None, keepalive=60, bind_address="", clean_start=MQTT_CLEAN_START_FIRST_ONLY, properties=None): """Connect to a remote broker. domain is the DNS domain to search for SRV records; if None, @@ -915,23 +907,21 @@ def connect_srv(self, domain=None, keepalive=60, bind_address="", """ if HAVE_DNS is False: - raise ValueError( - 'No DNS resolver library found, try "pip install dnspython" or "pip3 install dnspython3".') + raise ValueError('No DNS resolver library found, try "pip install dnspython" or "pip3 install dnspython3".') if domain is None: domain = socket.getfqdn() - domain = domain[domain.find('.') + 1:] + domain = domain[domain.find(".") + 1 :] try: - rr = f'_mqtt._tcp.{domain}' + rr = f"_mqtt._tcp.{domain}" if self._ssl: # IANA specifies secure-mqtt (not mqtts) for port 8883 - rr = f'_secure-mqtt._tcp.{domain}' + rr = f"_secure-mqtt._tcp.{domain}" answers = [] for answer in dns.resolver.query(rr, dns.rdatatype.SRV): addr = answer.target.to_text()[:-1] - answers.append( - (addr, answer.port, answer.priority, answer.weight)) + answers.append((addr, answer.port, answer.priority, answer.weight)) except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer, dns.resolver.NoNameservers) as err: raise ValueError(f"No answer/NXDOMAIN for SRV in {domain}") from err @@ -946,8 +936,7 @@ def connect_srv(self, domain=None, keepalive=60, bind_address="", raise ValueError("No SRV hosts responded") - def connect_async(self, host, port=1883, keepalive=60, bind_address="", bind_port=0, - clean_start=MQTT_CLEAN_START_FIRST_ONLY, properties=None): + def connect_async(self, host, port=1883, keepalive=60, bind_address="", bind_port=0, clean_start=MQTT_CLEAN_START_FIRST_ONLY, properties=None): """Connect to a remote broker asynchronously. This is a non-blocking connect call that can be used with loop_start() to provide very quick start. @@ -967,13 +956,13 @@ def connect_async(self, host, port=1883, keepalive=60, bind_address="", bind_por MQTT connect packet. Use the Properties class. """ if host is None or len(host) == 0: - raise ValueError('Invalid host.') + raise ValueError("Invalid host.") if port <= 0: - raise ValueError('Invalid port number.') + raise ValueError("Invalid port number.") if keepalive < 0: - raise ValueError('Keepalive must be >=0.') + raise ValueError("Keepalive must be >=0.") if bind_port < 0: - raise ValueError('Invalid bind port number.') + raise ValueError("Invalid bind port number.") self._host = host self._port = port @@ -984,14 +973,13 @@ def connect_async(self, host, port=1883, keepalive=60, bind_address="", bind_por self._connect_properties = properties self._state = mqtt_cs_connect_async - def reconnect_delay_set(self, min_delay=1, max_delay=120): - """ Configure the exponential reconnect delay + """Configure the exponential reconnect delay - When connection is lost, wait initially min_delay seconds and - double this time every attempt. The wait is capped at max_delay. - Once the client is fully connected (e.g. not only TCP socket, but - received a success CONNACK), the wait timer is reset to min_delay. + When connection is lost, wait initially min_delay seconds and + double this time every attempt. The wait is capped at max_delay. + Once the client is fully connected (e.g. not only TCP socket, but + received a success CONNACK), the wait timer is reset to min_delay. """ with self._reconnect_delay_mutex: self._reconnect_min_delay = min_delay @@ -1002,9 +990,9 @@ def reconnect(self): """Reconnect the client after a disconnect. Can only be called after connect()/connect_async().""" if len(self._host) == 0: - raise ValueError('Invalid host.') + raise ValueError("Invalid host.") if self._port <= 0: - raise ValueError('Invalid port number.') + raise ValueError("Invalid port number.") self._in_packet = { "command": 0, @@ -1014,7 +1002,8 @@ def reconnect(self): "remaining_length": 0, "packet": bytearray(b""), "to_process": 0, - "pos": 0} + "pos": 0, + } self._out_packet = collections.deque() @@ -1037,8 +1026,7 @@ def reconnect(self): try: on_pre_connect(self, self._userdata) except Exception as err: - self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_pre_connect: %s', err) + self._easy_log(MQTT_LOG_ERR, "Caught exception in on_pre_connect: %s", err) if not self.suppress_exceptions: raise @@ -1066,8 +1054,7 @@ def reconnect(self): ) else: # If SSL context has already checked hostname, then don't need to do it again - if (hasattr(self._ssl_context, 'check_hostname') and - self._ssl_context.check_hostname): + if hasattr(self._ssl_context, "check_hostname") and self._ssl_context.check_hostname: verify_host = False sock.settimeout(self._keepalive) @@ -1078,8 +1065,7 @@ def reconnect(self): if self._transport == "websockets": sock.settimeout(self._keepalive) - sock = WebsocketWrapper(sock, self._host, self._port, self._ssl, - self._websocket_path, self._websocket_extra_headers) + sock = WebsocketWrapper(sock, self._host, self._port, self._ssl, self._websocket_path, self._websocket_extra_headers) self._sock = sock self._sock.setblocking(0) @@ -1122,7 +1108,7 @@ def loop(self, timeout=1.0, max_packets=1): def _loop(self, timeout=1.0): if timeout < 0.0: - raise ValueError('Invalid timeout.') + raise ValueError("Invalid timeout.") try: packet = self._out_packet.popleft() @@ -1133,7 +1119,7 @@ def _loop(self, timeout=1.0): # used to check if there are any bytes left in the (SSL) socket pending_bytes = 0 - if hasattr(self._sock, 'pending'): + if hasattr(self._sock, "pending"): pending_bytes = self._sock.pending() # if bytes are pending do not wait in select @@ -1226,37 +1212,35 @@ def publish(self, topic, payload=None, qos=0, retain=False, properties=None): the length of the payload is greater than 268435455 bytes.""" if self._protocol != MQTTv5: if topic is None or len(topic) == 0: - raise ValueError('Invalid topic.') + raise ValueError("Invalid topic.") - topic = topic.encode('utf-8') + topic = topic.encode("utf-8") if self._topic_wildcard_len_check(topic) != MQTT_ERR_SUCCESS: - raise ValueError('Publish topic cannot contain wildcards.') + raise ValueError("Publish topic cannot contain wildcards.") if qos < 0 or qos > 2: - raise ValueError('Invalid QoS level.') + raise ValueError("Invalid QoS level.") if isinstance(payload, str): - local_payload = payload.encode('utf-8') + local_payload = payload.encode("utf-8") elif isinstance(payload, (bytes, bytearray)): local_payload = payload elif isinstance(payload, (int, float)): - local_payload = str(payload).encode('ascii') + local_payload = str(payload).encode("ascii") elif payload is None: - local_payload = b'' + local_payload = b"" else: - raise TypeError( - 'payload must be a string, bytearray, int, float or None.') + raise TypeError("payload must be a string, bytearray, int, float or None.") if len(local_payload) > 268435455: - raise ValueError('Payload too large.') + raise ValueError("Payload too large.") local_mid = self._mid_generate() if qos == 0: info = MQTTMessageInfo(local_mid) - rc = self._send_publish( - local_mid, topic, local_payload, qos, retain, False, info, properties) + rc = self._send_publish(local_mid, topic, local_payload, qos, retain, False, info, properties) info.rc = rc return info else: @@ -1285,8 +1269,7 @@ def publish(self, topic, payload=None, qos=0, retain=False, properties=None): elif qos == 2: message.state = mqtt_ms_wait_for_pubrec - rc = self._send_publish(message.mid, topic, message.payload, message.qos, message.retain, - message.dup, message.info, message.properties) + rc = self._send_publish(message.mid, topic, message.payload, message.qos, message.retain, message.dup, message.info, message.properties) # remove from inflight messages so it will be send after a connection is made if rc is MQTT_ERR_NO_CONN: @@ -1314,10 +1297,10 @@ def username_pw_set(self, username, password=None): """ # [MQTT-3.1.3-11] User name must be UTF-8 encoded string - self._username = None if username is None else username.encode('utf-8') + self._username = None if username is None else username.encode("utf-8") self._password = password if isinstance(self._password, str): - self._password = self._password.encode('utf-8') + self._password = self._password.encode("utf-8") def enable_bridge_mode(self): """Sets the client in a bridge mode instead of client mode. @@ -1444,29 +1427,26 @@ def subscribe(self, topic, qos=0, options=None, properties=None): if self._protocol == MQTTv5: topic, options = topic if not isinstance(options, SubscribeOptions): - raise ValueError( - 'Subscribe options must be instance of SubscribeOptions class.') + raise ValueError("Subscribe options must be instance of SubscribeOptions class.") else: topic, qos = topic if isinstance(topic, (bytes, str)): if qos < 0 or qos > 2: - raise ValueError('Invalid QoS level.') + raise ValueError("Invalid QoS level.") if self._protocol == MQTTv5: if options is None: # if no options are provided, use the QoS passed instead options = SubscribeOptions(qos=qos) elif qos != 0: - raise ValueError( - 'Subscribe options and qos parameters cannot be combined.') + raise ValueError("Subscribe options and qos parameters cannot be combined.") if not isinstance(options, SubscribeOptions): - raise ValueError( - 'Subscribe options must be instance of SubscribeOptions class.') - topic_qos_list = [(topic.encode('utf-8'), options)] + raise ValueError("Subscribe options must be instance of SubscribeOptions class.") + topic_qos_list = [(topic.encode("utf-8"), options)] else: if topic is None or len(topic) == 0: - raise ValueError('Invalid topic.') - topic_qos_list = [(topic.encode('utf-8'), qos)] + raise ValueError("Invalid topic.") + topic_qos_list = [(topic.encode("utf-8"), qos)] elif isinstance(topic, list): topic_qos_list = [] if self._protocol == MQTTv5: @@ -1474,22 +1454,22 @@ def subscribe(self, topic, qos=0, options=None, properties=None): if not isinstance(o, SubscribeOptions): # then the second value should be QoS if o < 0 or o > 2: - raise ValueError('Invalid QoS level.') + raise ValueError("Invalid QoS level.") o = SubscribeOptions(qos=o) - topic_qos_list.append((t.encode('utf-8'), o)) + topic_qos_list.append((t.encode("utf-8"), o)) else: for t, q in topic: if q < 0 or q > 2: - raise ValueError('Invalid QoS level.') + raise ValueError("Invalid QoS level.") if t is None or len(t) == 0 or not isinstance(t, (bytes, str)): - raise ValueError('Invalid topic.') - topic_qos_list.append((t.encode('utf-8'), q)) + raise ValueError("Invalid topic.") + topic_qos_list.append((t.encode("utf-8"), q)) if topic_qos_list is None: raise ValueError("No topic specified, or incorrect topic type.") if any(self._filter_wildcard_len_check(topic) != MQTT_ERR_SUCCESS for topic, _ in topic_qos_list): - raise ValueError('Invalid subscription filter.') + raise ValueError("Invalid subscription filter.") if self._sock is None: return (MQTT_ERR_NO_CONN, None) @@ -1516,17 +1496,17 @@ def unsubscribe(self, topic, properties=None): """ topic_list = None if topic is None: - raise ValueError('Invalid topic.') + raise ValueError("Invalid topic.") if isinstance(topic, (bytes, str)): if len(topic) == 0: - raise ValueError('Invalid topic.') - topic_list = [topic.encode('utf-8')] + raise ValueError("Invalid topic.") + topic_list = [topic.encode("utf-8")] elif isinstance(topic, list): topic_list = [] for t in topic: if len(t) == 0 or not isinstance(t, (bytes, str)): - raise ValueError('Invalid topic.') - topic_list.append(t.encode('utf-8')) + raise ValueError("Invalid topic.") + topic_list.append(t.encode("utf-8")) if topic_list is None: raise ValueError("No topic specified, or incorrect topic type.") @@ -1630,16 +1610,16 @@ def max_inflight_messages_set(self, inflight): """Set the maximum number of messages with QoS>0 that can be part way through their network flow at once. Defaults to 20.""" if inflight < 0: - raise ValueError('Invalid inflight.') + raise ValueError("Invalid inflight.") self._max_inflight_messages = inflight def max_queued_messages_set(self, queue_size): """Set the maximum number of messages in the outgoing message queue. 0 means unlimited.""" if queue_size < 0: - raise ValueError('Invalid queue size.') + raise ValueError("Invalid queue size.") if not isinstance(queue_size, int): - raise ValueError('Invalid type of queue size.') + raise ValueError("Invalid type of queue size.") self._max_queued_messages = queue_size return self @@ -1676,35 +1656,33 @@ def will_set(self, topic, payload=None, qos=0, retain=False, properties=None): zero string length. """ if topic is None or len(topic) == 0: - raise ValueError('Invalid topic.') + raise ValueError("Invalid topic.") if qos < 0 or qos > 2: - raise ValueError('Invalid QoS level.') + raise ValueError("Invalid QoS level.") if properties and not isinstance(properties, Properties): - raise ValueError( - "The properties argument must be an instance of the Properties class.") + raise ValueError("The properties argument must be an instance of the Properties class.") if isinstance(payload, str): - self._will_payload = payload.encode('utf-8') + self._will_payload = payload.encode("utf-8") elif isinstance(payload, (bytes, bytearray)): self._will_payload = payload elif isinstance(payload, (int, float)): - self._will_payload = str(payload).encode('ascii') + self._will_payload = str(payload).encode("ascii") elif payload is None: self._will_payload = b"" else: - raise TypeError( - 'payload must be a string, bytearray, int, float or None.') + raise TypeError("payload must be a string, bytearray, int, float or None.") self._will = True - self._will_topic = topic.encode('utf-8') + self._will_topic = topic.encode("utf-8") self._will_qos = qos self._will_retain = retain self._will_properties = properties def will_clear(self): - """ Removes a will that was previously configured with will_set(). + """Removes a will that was previously configured with will_set(). Must be called before connect() to have any effect.""" self._will = False @@ -1749,8 +1727,7 @@ def loop_forever(self, timeout=1.0, max_packets=1, retry_first_connection=False) self._handle_on_connect_fail() if not retry_first_connection: raise - self._easy_log( - MQTT_LOG_DEBUG, "Connection failed, retrying") + self._easy_log(MQTT_LOG_DEBUG, "Connection failed, retrying") self._reconnect_wait() else: break @@ -1763,17 +1740,15 @@ def loop_forever(self, timeout=1.0, max_packets=1, retry_first_connection=False) # either called loop_forever() when in single threaded mode, or # in multi threaded mode when loop_stop() has been called and # so no other threads can access _out_packet or _messages. - if (self._thread_terminate is True - and len(self._out_packet) == 0 - and len(self._out_messages) == 0): + if self._thread_terminate is True and len(self._out_packet) == 0 and len(self._out_messages) == 0: rc = 1 run = False def should_exit(): return ( - self._state == mqtt_cs_disconnecting or - run is False or # noqa: B023 (uses the run variable from the outer scope on purpose) - self._thread_terminate is True + self._state == mqtt_cs_disconnecting + or run is False # noqa: B023 (uses the run variable from the outer scope on purpose) + or self._thread_terminate is True ) if should_exit() or not self._reconnect_on_failure: @@ -1788,8 +1763,7 @@ def should_exit(): self.reconnect() except (OSError, WebsocketConnectionError): self._handle_on_connect_fail() - self._easy_log( - MQTT_LOG_DEBUG, "Connection failed, retrying") + self._easy_log(MQTT_LOG_DEBUG, "Connection failed, retrying") return rc @@ -1830,7 +1804,7 @@ def on_log(self): @on_log.setter def on_log(self, func): - """ Define the logging callback implementation. + """Define the logging callback implementation. Expected signature is: log_callback(client, userdata, level, buf) @@ -1851,6 +1825,7 @@ def log_callback(self): def decorator(func): self.on_log = func return func + return decorator @property @@ -1861,7 +1836,7 @@ def on_pre_connect(self): @on_pre_connect.setter def on_pre_connect(self, func): - """ Define the pre_connect callback implementation. + """Define the pre_connect callback implementation. Expected signature: connect_callback(client, userdata) @@ -1880,6 +1855,7 @@ def pre_connect_callback(self): def decorator(func): self.on_pre_connect = func return func + return decorator @property @@ -1890,7 +1866,7 @@ def on_connect(self): @on_connect.setter def on_connect(self, func): - """ Define the connect callback implementation. + """Define the connect callback implementation. Expected signature for MQTT v3.1 and v3.1.1 is: connect_callback(client, userdata, flags, rc) @@ -1936,6 +1912,7 @@ def connect_callback(self): def decorator(func): self.on_connect = func return func + return decorator @property @@ -1946,7 +1923,7 @@ def on_connect_fail(self): @on_connect_fail.setter def on_connect_fail(self, func): - """ Define the connection failure callback implementation + """Define the connection failure callback implementation Expected signature is: on_connect_fail(client, userdata) @@ -1965,6 +1942,7 @@ def connect_fail_callback(self): def decorator(func): self.on_connect_fail = func return func + return decorator @property @@ -1975,7 +1953,7 @@ def on_subscribe(self): @on_subscribe.setter def on_subscribe(self, func): - """ Define the subscribe callback implementation. + """Define the subscribe callback implementation. Expected signature for MQTT v3.1.1 and v3.1 is: subscribe_callback(client, userdata, mid, granted_qos) @@ -2004,6 +1982,7 @@ def subscribe_callback(self): def decorator(func): self.on_subscribe = func return func + return decorator @property @@ -2018,7 +1997,7 @@ def on_message(self): @on_message.setter def on_message(self, func): - """ Define the message received callback implementation. + """Define the message received callback implementation. Expected signature is: on_message_callback(client, userdata, message) @@ -2039,6 +2018,7 @@ def message_callback(self): def decorator(func): self.on_message = func return func + return decorator @property @@ -2055,7 +2035,7 @@ def on_publish(self): @on_publish.setter def on_publish(self, func): - """ Define the published message callback implementation. + """Define the published message callback implementation. Expected signature is: on_publish_callback(client, userdata, mid) @@ -2076,6 +2056,7 @@ def publish_callback(self): def decorator(func): self.on_publish = func return func + return decorator @property @@ -2086,7 +2067,7 @@ def on_unsubscribe(self): @on_unsubscribe.setter def on_unsubscribe(self, func): - """ Define the unsubscribe callback implementation. + """Define the unsubscribe callback implementation. Expected signature for MQTT v3.1.1 and v3.1 is: unsubscribe_callback(client, userdata, mid) @@ -2113,17 +2094,17 @@ def unsubscribe_callback(self): def decorator(func): self.on_unsubscribe = func return func + return decorator @property def on_disconnect(self): - """If implemented, called when the client disconnects from the broker. - """ + """If implemented, called when the client disconnects from the broker.""" return self._on_disconnect @on_disconnect.setter def on_disconnect(self, func): - """ Define the disconnect callback implementation. + """Define the disconnect callback implementation. Expected signature for MQTT v3.1.1 and v3.1 is: disconnect_callback(client, userdata, rc) @@ -2150,6 +2131,7 @@ def disconnect_callback(self): def decorator(func): self.on_disconnect = func return func + return decorator @property @@ -2180,6 +2162,7 @@ def socket_open_callback(self): def decorator(func): self.on_socket_open = func return func + return decorator def _call_socket_open(self): @@ -2192,8 +2175,7 @@ def _call_socket_open(self): try: on_socket_open(self, self._userdata, self._sock) except Exception as err: - self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_socket_open: %s', err) + self._easy_log(MQTT_LOG_ERR, "Caught exception in on_socket_open: %s", err) if not self.suppress_exceptions: raise @@ -2225,6 +2207,7 @@ def socket_close_callback(self): def decorator(func): self.on_socket_close = func return func + return decorator def _call_socket_close(self, sock): @@ -2237,8 +2220,7 @@ def _call_socket_close(self, sock): try: on_socket_close(self, self._userdata, sock) except Exception as err: - self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_socket_close: %s', err) + self._easy_log(MQTT_LOG_ERR, "Caught exception in on_socket_close: %s", err) if not self.suppress_exceptions: raise @@ -2270,6 +2252,7 @@ def socket_register_write_callback(self): def decorator(func): self._on_socket_register_write = func return func + return decorator def _call_socket_register_write(self): @@ -2282,11 +2265,9 @@ def _call_socket_register_write(self): if on_socket_register_write: try: - on_socket_register_write( - self, self._userdata, self._sock) + on_socket_register_write(self, self._userdata, self._sock) except Exception as err: - self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_socket_register_write: %s', err) + self._easy_log(MQTT_LOG_ERR, "Caught exception in on_socket_register_write: %s", err) if not self.suppress_exceptions: raise @@ -2318,6 +2299,7 @@ def socket_unregister_write_callback(self): def decorator(func): self._on_socket_unregister_write = func return func + return decorator def _call_socket_unregister_write(self, sock=None): @@ -2334,8 +2316,7 @@ def _call_socket_unregister_write(self, sock=None): try: on_socket_unregister_write(self, self._userdata, sock) except Exception as err: - self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_socket_unregister_write: %s', err) + self._easy_log(MQTT_LOG_ERR, "Caught exception in on_socket_unregister_write: %s", err) if not self.suppress_exceptions: raise @@ -2360,6 +2341,7 @@ def topic_callback(self, sub): def decorator(func): self.message_callback_add(sub, func) return func + return decorator def message_callback_remove(self, sub): @@ -2403,26 +2385,24 @@ def _packet_read(self): # fail due to longer length, so save current data and current position. # After all data is read, send to _mqtt_handle_packet() to deal with. # Finally, free the memory and reset everything to starting conditions. - if self._in_packet['command'] == 0: + if self._in_packet["command"] == 0: try: command = self._sock_recv(1) except BlockingIOError: return MQTT_ERR_AGAIN except ConnectionError as err: - self._easy_log( - MQTT_LOG_ERR, 'failed to receive on socket: %s', err) + self._easy_log(MQTT_LOG_ERR, "failed to receive on socket: %s", err) return MQTT_ERR_CONN_LOST except TimeoutError as err: - self._easy_log( - MQTT_LOG_ERR, 'timeout on socket: %s', err) + self._easy_log(MQTT_LOG_ERR, "timeout on socket: %s", err) return MQTT_ERR_CONN_LOST else: if len(command) == 0: return MQTT_ERR_CONN_LOST - command, = struct.unpack("!B", command) - self._in_packet['command'] = command + (command,) = struct.unpack("!B", command) + self._in_packet["command"] = command - if self._in_packet['have_remaining'] == 0: + if self._in_packet["have_remaining"] == 0: # Read remaining # Algorithm for decoding taken from pseudo code at # http://publib.boulder.ibm.com/infocenter/wmbhelp/v6r0m0/topic/com.ibm.etools.mft.doc/ac10870_.htm @@ -2432,44 +2412,41 @@ def _packet_read(self): except BlockingIOError: return MQTT_ERR_AGAIN except ConnectionError as err: - self._easy_log( - MQTT_LOG_ERR, 'failed to receive on socket: %s', err) + self._easy_log(MQTT_LOG_ERR, "failed to receive on socket: %s", err) return MQTT_ERR_CONN_LOST else: if len(byte) == 0: return MQTT_ERR_CONN_LOST - byte, = struct.unpack("!B", byte) - self._in_packet['remaining_count'].append(byte) + (byte,) = struct.unpack("!B", byte) + self._in_packet["remaining_count"].append(byte) # Max 4 bytes length for remaining length as defined by protocol. # Anything more likely means a broken/malicious client. - if len(self._in_packet['remaining_count']) > 4: + if len(self._in_packet["remaining_count"]) > 4: return MQTT_ERR_PROTOCOL - self._in_packet['remaining_length'] += ( - byte & 127) * self._in_packet['remaining_mult'] - self._in_packet['remaining_mult'] = self._in_packet['remaining_mult'] * 128 + self._in_packet["remaining_length"] += (byte & 127) * self._in_packet["remaining_mult"] + self._in_packet["remaining_mult"] = self._in_packet["remaining_mult"] * 128 if (byte & 128) == 0: break - self._in_packet['have_remaining'] = 1 - self._in_packet['to_process'] = self._in_packet['remaining_length'] + self._in_packet["have_remaining"] = 1 + self._in_packet["to_process"] = self._in_packet["remaining_length"] - count = 100 # Don't get stuck in this loop if we have a huge message. - while self._in_packet['to_process'] > 0: + count = 100 # Don't get stuck in this loop if we have a huge message. + while self._in_packet["to_process"] > 0: try: - data = self._sock_recv(self._in_packet['to_process']) + data = self._sock_recv(self._in_packet["to_process"]) except BlockingIOError: return MQTT_ERR_AGAIN except ConnectionError as err: - self._easy_log( - MQTT_LOG_ERR, 'failed to receive on socket: %s', err) + self._easy_log(MQTT_LOG_ERR, "failed to receive on socket: %s", err) return MQTT_ERR_CONN_LOST else: if len(data) == 0: return MQTT_ERR_CONN_LOST - self._in_packet['to_process'] -= len(data) - self._in_packet['packet'] += data + self._in_packet["to_process"] -= len(data) + self._in_packet["packet"] += data count -= 1 if count == 0: with self._msgtime_mutex: @@ -2477,19 +2454,20 @@ def _packet_read(self): return MQTT_ERR_AGAIN # All data for this packet is read. - self._in_packet['pos'] = 0 + self._in_packet["pos"] = 0 rc = self._packet_handle() # Free data and reset values self._in_packet = { - 'command': 0, - 'have_remaining': 0, - 'remaining_count': [], - 'remaining_mult': 1, - 'remaining_length': 0, - 'packet': bytearray(b""), - 'to_process': 0, - 'pos': 0} + "command": 0, + "have_remaining": 0, + "remaining_count": [], + "remaining_mult": 1, + "remaining_length": 0, + "packet": bytearray(b""), + "to_process": 0, + "pos": 0, + } with self._msgtime_mutex: self._last_msg_in = time_func() @@ -2503,8 +2481,7 @@ def _packet_write(self): return MQTT_ERR_SUCCESS try: - write_length = self._sock_send( - packet['packet'][packet['pos']:]) + write_length = self._sock_send(packet["packet"][packet["pos"] :]) except (AttributeError, ValueError): self._out_packet.appendleft(packet) return MQTT_ERR_SUCCESS @@ -2513,33 +2490,30 @@ def _packet_write(self): return MQTT_ERR_AGAIN except ConnectionError as err: self._out_packet.appendleft(packet) - self._easy_log( - MQTT_LOG_ERR, 'failed to receive on socket: %s', err) + self._easy_log(MQTT_LOG_ERR, "failed to receive on socket: %s", err) return MQTT_ERR_CONN_LOST if write_length > 0: - packet['to_process'] -= write_length - packet['pos'] += write_length + packet["to_process"] -= write_length + packet["pos"] += write_length - if packet['to_process'] == 0: - if (packet['command'] & 0xF0) == PUBLISH and packet['qos'] == 0: + if packet["to_process"] == 0: + if (packet["command"] & 0xF0) == PUBLISH and packet["qos"] == 0: with self._callback_mutex: on_publish = self.on_publish if on_publish: with self._in_callback_mutex: try: - on_publish( - self, self._userdata, packet['mid']) + on_publish(self, self._userdata, packet["mid"]) except Exception as err: - self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_publish: %s', err) + self._easy_log(MQTT_LOG_ERR, "Caught exception in on_publish: %s", err) if not self.suppress_exceptions: raise - packet['info']._set_as_published() + packet["info"]._set_as_published() - if (packet['command'] & 0xF0) == DISCONNECT: + if (packet["command"] & 0xF0) == DISCONNECT: with self._msgtime_mutex: self._last_msg_out = time_func() @@ -2613,16 +2587,14 @@ def _topic_wildcard_len_check(topic): # Search for + or # in a topic. Return MQTT_ERR_INVAL if found. # Also returns MQTT_ERR_INVAL if the topic string is too long. # Returns MQTT_ERR_SUCCESS if everything is fine. - if b'+' in topic or b'#' in topic or len(topic) > 65535: + if b"+" in topic or b"#" in topic or len(topic) > 65535: return MQTT_ERR_INVAL else: return MQTT_ERR_SUCCESS @staticmethod def _filter_wildcard_len_check(sub): - if (len(sub) == 0 or len(sub) > 65535 - or any(b'+' in p or b'#' in p for p in sub.split(b'/') if len(p) > 1) - or b'#/' in sub): + if len(sub) == 0 or len(sub) > 65535 or any(b"+" in p or b"#" in p for p in sub.split(b"/") if len(p) > 1) or b"#/" in sub: return MQTT_ERR_INVAL else: return MQTT_ERR_SUCCESS @@ -2663,16 +2635,16 @@ def _pack_remaining_length(self, packet, remaining_length): def _pack_str16(self, packet, data): if isinstance(data, str): - data = data.encode('utf-8') + data = data.encode("utf-8") packet.extend(struct.pack("!H", len(data))) packet.extend(data) - def _send_publish(self, mid, topic, payload=b'', qos=0, retain=False, dup=False, info=None, properties=None): + def _send_publish(self, mid, topic, payload=b"", qos=0, retain=False, dup=False, info=None, properties=None): # we assume that topic and payload are already properly encoded if not isinstance(topic, bytes): - raise TypeError('topic must be bytes, not str') + raise TypeError("topic must be bytes, not str") if payload and not isinstance(payload, bytes): - raise TypeError('payload must be bytes if set') + raise TypeError("payload must be bytes if set") if self._sock is None: return MQTT_ERR_NO_CONN @@ -2686,30 +2658,16 @@ def _send_publish(self, mid, topic, payload=b'', qos=0, retain=False, dup=False, if payloadlen == 0: if self._protocol == MQTTv5: - self._easy_log( - MQTT_LOG_DEBUG, - "Sending PUBLISH (d%d, q%d, r%d, m%d), '%s', properties=%s (NULL payload)", - dup, qos, retain, mid, topic, properties - ) + self._easy_log(MQTT_LOG_DEBUG, "Sending PUBLISH (d%d, q%d, r%d, m%d), '%s', properties=%s (NULL payload)", dup, qos, retain, mid, topic, properties) else: - self._easy_log( - MQTT_LOG_DEBUG, - "Sending PUBLISH (d%d, q%d, r%d, m%d), '%s' (NULL payload)", - dup, qos, retain, mid, topic - ) + self._easy_log(MQTT_LOG_DEBUG, "Sending PUBLISH (d%d, q%d, r%d, m%d), '%s' (NULL payload)", dup, qos, retain, mid, topic) else: if self._protocol == MQTTv5: self._easy_log( - MQTT_LOG_DEBUG, - "Sending PUBLISH (d%d, q%d, r%d, m%d), '%s', properties=%s, ... (%d bytes)", - dup, qos, retain, mid, topic, properties, payloadlen + MQTT_LOG_DEBUG, "Sending PUBLISH (d%d, q%d, r%d, m%d), '%s', properties=%s, ... (%d bytes)", dup, qos, retain, mid, topic, properties, payloadlen ) else: - self._easy_log( - MQTT_LOG_DEBUG, - "Sending PUBLISH (d%d, q%d, r%d, m%d), '%s', ... (%d bytes)", - dup, qos, retain, mid, topic, payloadlen - ) + self._easy_log(MQTT_LOG_DEBUG, "Sending PUBLISH (d%d, q%d, r%d, m%d), '%s', ... (%d bytes)", dup, qos, retain, mid, topic, payloadlen) if qos > 0: # For message id @@ -2717,7 +2675,7 @@ def _send_publish(self, mid, topic, payload=b'', qos=0, retain=False, dup=False, if self._protocol == MQTTv5: if properties is None: - packed_properties = b'\x00' + packed_properties = b"\x00" else: packed_properties = properties.pack() remaining_length += len(packed_properties) @@ -2750,13 +2708,13 @@ def _send_command_with_mid(self, command, mid, dup): command |= 0x8 remaining_length = 2 - packet = struct.pack('!BBH', command, remaining_length, mid) + packet = struct.pack("!BBH", command, remaining_length, mid) return self._packet_queue(command, packet, mid, 1) def _send_simple_command(self, command): # For DISCONNECT, PINGREQ and PINGRESP remaining_length = 0 - packet = struct.pack('!BB', command, remaining_length) + packet = struct.pack("!BB", command, remaining_length) return self._packet_queue(command, packet, 0, 0) def _send_connect(self, keepalive): @@ -2764,8 +2722,7 @@ def _send_connect(self, keepalive): # hard-coded UTF-8 encoded string protocol = b"MQTT" if proto_ver >= MQTTv311 else b"MQIsdp" - remaining_length = 2 + len(protocol) + 1 + \ - 1 + 2 + 2 + len(self._client_id) + remaining_length = 2 + len(protocol) + 1 + 1 + 2 + 2 + len(self._client_id) connect_flags = 0 if self._protocol == MQTTv5: @@ -2777,10 +2734,8 @@ def _send_connect(self, keepalive): connect_flags |= 0x02 if self._will: - remaining_length += 2 + \ - len(self._will_topic) + 2 + len(self._will_payload) - connect_flags |= 0x04 | ((self._will_qos & 0x03) << 3) | ( - (self._will_retain & 0x01) << 5) + remaining_length += 2 + len(self._will_topic) + 2 + len(self._will_payload) + connect_flags |= 0x04 | ((self._will_qos & 0x03) << 3) | ((self._will_retain & 0x01) << 5) if self._username is not None: remaining_length += 2 + len(self._username) @@ -2791,13 +2746,13 @@ def _send_connect(self, keepalive): if self._protocol == MQTTv5: if self._connect_properties is None: - packed_connect_properties = b'\x00' + packed_connect_properties = b"\x00" else: packed_connect_properties = self._connect_properties.pack() remaining_length += len(packed_connect_properties) if self._will: if self._will_properties is None: - packed_will_properties = b'\x00' + packed_will_properties = b"\x00" else: packed_will_properties = self._will_properties.pack() remaining_length += len(packed_will_properties) @@ -2812,10 +2767,16 @@ def _send_connect(self, keepalive): proto_ver |= 0x80 self._pack_remaining_length(packet, remaining_length) - packet.extend(struct.pack( - f"!H{len(protocol)}sBBH", - len(protocol), protocol, proto_ver, connect_flags, keepalive, - )) + packet.extend( + struct.pack( + f"!H{len(protocol)}sBBH", + len(protocol), + protocol, + proto_ver, + connect_flags, + keepalive, + ) + ) if self._protocol == MQTTv5: packet += packed_connect_properties @@ -2847,7 +2808,7 @@ def _send_connect(self, keepalive): (connect_flags & 0x2) >> 1, keepalive, self._client_id, - self._connect_properties + self._connect_properties, ) else: self._easy_log( @@ -2860,16 +2821,13 @@ def _send_connect(self, keepalive): (connect_flags & 0x4) >> 2, (connect_flags & 0x2) >> 1, keepalive, - self._client_id + self._client_id, ) return self._packet_queue(command, packet, 0, 0) def _send_disconnect(self, reasoncode=None, properties=None): if self._protocol == MQTTv5: - self._easy_log(MQTT_LOG_DEBUG, "Sending DISCONNECT reasonCode=%s properties=%s", - reasoncode, - properties - ) + self._easy_log(MQTT_LOG_DEBUG, "Sending DISCONNECT reasonCode=%s properties=%s", reasoncode, properties) else: self._easy_log(MQTT_LOG_DEBUG, "Sending DISCONNECT") @@ -2902,7 +2860,7 @@ def _send_subscribe(self, dup, topics, properties=None): remaining_length = 2 if self._protocol == MQTTv5: if properties is None: - packed_subscribe_properties = b'\x00' + packed_subscribe_properties = b"\x00" else: packed_subscribe_properties = properties.pack() remaining_length += len(packed_subscribe_properties) @@ -2939,7 +2897,7 @@ def _send_unsubscribe(self, dup, topics, properties=None): remaining_length = 2 if self._protocol == MQTTv5: if properties is None: - packed_unsubscribe_properties = b'\x00' + packed_unsubscribe_properties = b"\x00" else: packed_unsubscribe_properties = properties.pack() remaining_length += len(packed_unsubscribe_properties) @@ -3035,14 +2993,7 @@ def _messages_reconnect_reset(self): self._messages_reconnect_reset_in() def _packet_queue(self, command, packet, mid, qos, info=None): - mpkt = { - 'command': command, - 'mid': mid, - 'qos': qos, - 'pos': 0, - 'to_process': len(packet), - 'packet': packet, - 'info': info} + mpkt = {"command": command, "mid": mid, "qos": qos, "pos": 0, "to_process": len(packet), "packet": packet, "info": info} self._out_packet.append(mpkt) @@ -3066,7 +3017,7 @@ def _packet_queue(self, command, packet, mid, qos, info=None): return MQTT_ERR_SUCCESS def _packet_handle(self): - cmd = self._in_packet['command'] & 0xF0 + cmd = self._in_packet["command"] & 0xF0 if cmd == PINGREQ: return self._handle_pingreq() elif cmd == PINGRESP: @@ -3095,14 +3046,14 @@ def _packet_handle(self): return MQTT_ERR_PROTOCOL def _handle_pingreq(self): - if self._in_packet['remaining_length'] != 0: + if self._in_packet["remaining_length"] != 0: return MQTT_ERR_PROTOCOL self._easy_log(MQTT_LOG_DEBUG, "Received PINGREQ") return self._send_pingresp() def _handle_pingresp(self): - if self._in_packet['remaining_length'] != 0: + if self._in_packet["remaining_length"] != 0: return MQTT_ERR_PROTOCOL # No longer waiting for a PINGRESP. @@ -3112,45 +3063,40 @@ def _handle_pingresp(self): def _handle_connack(self): if self._protocol == MQTTv5: - if self._in_packet['remaining_length'] < 2: + if self._in_packet["remaining_length"] < 2: return MQTT_ERR_PROTOCOL - elif self._in_packet['remaining_length'] != 2: + elif self._in_packet["remaining_length"] != 2: return MQTT_ERR_PROTOCOL if self._protocol == MQTTv5: - (flags, result) = struct.unpack( - "!BB", self._in_packet['packet'][:2]) + (flags, result) = struct.unpack("!BB", self._in_packet["packet"][:2]) if result == 1: # This is probably a failure from a broker that doesn't support # MQTT v5. - reason = 132 # Unsupported protocol version + reason = 132 # Unsupported protocol version properties = None else: reason = ReasonCodes(CONNACK >> 4, identifier=result) properties = Properties(CONNACK >> 4) - properties.unpack(self._in_packet['packet'][2:]) + properties.unpack(self._in_packet["packet"][2:]) else: - (flags, result) = struct.unpack("!BB", self._in_packet['packet']) + (flags, result) = struct.unpack("!BB", self._in_packet["packet"]) if self._protocol == MQTTv311: if result == CONNACK_REFUSED_PROTOCOL_VERSION: if not self._reconnect_on_failure: return MQTT_ERR_PROTOCOL - self._easy_log( - MQTT_LOG_DEBUG, - "Received CONNACK (%s, %s), attempting downgrade to MQTT v3.1.", - flags, result - ) + self._easy_log(MQTT_LOG_DEBUG, "Received CONNACK (%s, %s), attempting downgrade to MQTT v3.1.", flags, result) # Downgrade to MQTT v3.1 self._protocol = MQTTv31 return self.reconnect() - elif (result == CONNACK_REFUSED_IDENTIFIER_REJECTED - and self._client_id == b''): + elif result == CONNACK_REFUSED_IDENTIFIER_REJECTED and self._client_id == b"": if not self._reconnect_on_failure: return MQTT_ERR_PROTOCOL self._easy_log( MQTT_LOG_DEBUG, "Received CONNACK (%s, %s), attempting to use non-empty CID", - flags, result, + flags, + result, ) self._client_id = base62(uuid.uuid4().int, padding=22) return self.reconnect() @@ -3160,11 +3106,9 @@ def _handle_connack(self): self._reconnect_delay = None if self._protocol == MQTTv5: - self._easy_log( - MQTT_LOG_DEBUG, "Received CONNACK (%s, %s) properties=%s", flags, reason, properties) + self._easy_log(MQTT_LOG_DEBUG, "Received CONNACK (%s, %s) properties=%s", flags, reason, properties) else: - self._easy_log( - MQTT_LOG_DEBUG, "Received CONNACK (%s, %s)", flags, result) + self._easy_log(MQTT_LOG_DEBUG, "Received CONNACK (%s, %s)", flags, result) # it won't be the first successful connect any more self._mqttv5_first_connect = False @@ -3174,18 +3118,15 @@ def _handle_connack(self): if on_connect: flags_dict = {} - flags_dict['session present'] = flags & 0x01 + flags_dict["session present"] = flags & 0x01 with self._in_callback_mutex: try: if self._protocol == MQTTv5: - on_connect(self, self._userdata, - flags_dict, reason, properties) + on_connect(self, self._userdata, flags_dict, reason, properties) else: - on_connect( - self, self._userdata, flags_dict, result) + on_connect(self, self._userdata, flags_dict, result) except Exception as err: - self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_connect: %s', err) + self._easy_log(MQTT_LOG_ERR, "Caught exception in on_connect: %s", err) if not self.suppress_exceptions: raise @@ -3200,15 +3141,7 @@ def _handle_connack(self): if m.qos == 0: with self._in_callback_mutex: # Don't call loop_write after _send_publish() - rc = self._send_publish( - m.mid, - m.topic.encode('utf-8'), - m.payload, - m.qos, - m.retain, - m.dup, - properties=m.properties - ) + rc = self._send_publish(m.mid, m.topic.encode("utf-8"), m.payload, m.qos, m.retain, m.dup, properties=m.properties) if rc != 0: return rc elif m.qos == 1: @@ -3216,15 +3149,7 @@ def _handle_connack(self): self._inflight_messages += 1 m.state = mqtt_ms_wait_for_puback with self._in_callback_mutex: # Don't call loop_write after _send_publish() - rc = self._send_publish( - m.mid, - m.topic.encode('utf-8'), - m.payload, - m.qos, - m.retain, - m.dup, - properties=m.properties - ) + rc = self._send_publish(m.mid, m.topic.encode("utf-8"), m.payload, m.qos, m.retain, m.dup, properties=m.properties) if rc != 0: return rc elif m.qos == 2: @@ -3232,15 +3157,7 @@ def _handle_connack(self): self._inflight_messages += 1 m.state = mqtt_ms_wait_for_pubrec with self._in_callback_mutex: # Don't call loop_write after _send_publish() - rc = self._send_publish( - m.mid, - m.topic.encode('utf-8'), - m.payload, - m.qos, - m.retain, - m.dup, - properties=m.properties - ) + rc = self._send_publish(m.mid, m.topic.encode("utf-8"), m.payload, m.qos, m.retain, m.dup, properties=m.properties) if rc != 0: return rc elif m.state == mqtt_ms_resend_pubrel: @@ -3261,17 +3178,13 @@ def _handle_connack(self): def _handle_disconnect(self): packet_type = DISCONNECT >> 4 reasonCode = properties = None - if self._in_packet['remaining_length'] > 2: + if self._in_packet["remaining_length"] > 2: reasonCode = ReasonCodes(packet_type) - reasonCode.unpack(self._in_packet['packet']) - if self._in_packet['remaining_length'] > 3: + reasonCode.unpack(self._in_packet["packet"]) + if self._in_packet["remaining_length"] > 3: properties = Properties(packet_type) - props, props_len = properties.unpack( - self._in_packet['packet'][1:]) - self._easy_log(MQTT_LOG_DEBUG, "Received DISCONNECT %s %s", - reasonCode, - properties - ) + props, props_len = properties.unpack(self._in_packet["packet"][1:]) + self._easy_log(MQTT_LOG_DEBUG, "Received DISCONNECT %s %s", reasonCode, properties) self._loop_rc_handle(reasonCode, properties) @@ -3280,7 +3193,7 @@ def _handle_disconnect(self): def _handle_suback(self): self._easy_log(MQTT_LOG_DEBUG, "Received SUBACK") pack_format = f"!H{len(self._in_packet['packet']) - 2}s" - (mid, packet) = struct.unpack(pack_format, self._in_packet['packet']) + (mid, packet) = struct.unpack(pack_format, self._in_packet["packet"]) if self._protocol == MQTTv5: properties = Properties(SUBACK >> 4) @@ -3299,14 +3212,11 @@ def _handle_suback(self): with self._in_callback_mutex: # Don't call loop_write after _send_publish() try: if self._protocol == MQTTv5: - on_subscribe( - self, self._userdata, mid, reasoncodes, properties) + on_subscribe(self, self._userdata, mid, reasoncodes, properties) else: - on_subscribe( - self, self._userdata, mid, granted_qos) + on_subscribe(self, self._userdata, mid, granted_qos) except Exception as err: - self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_subscribe: %s', err) + self._easy_log(MQTT_LOG_ERR, "Caught exception in on_subscribe: %s", err) if not self.suppress_exceptions: raise @@ -3315,14 +3225,14 @@ def _handle_suback(self): def _handle_publish(self): rc = 0 - header = self._in_packet['command'] + header = self._in_packet["command"] message = MQTTMessage() message.dup = (header & 0x08) >> 3 message.qos = (header & 0x06) >> 1 - message.retain = (header & 0x01) + message.retain = header & 0x01 pack_format = f"!H{len(self._in_packet['packet']) - 2}s" - (slen, packet) = struct.unpack(pack_format, self._in_packet['packet']) + (slen, packet) = struct.unpack(pack_format, self._in_packet["packet"]) pack_format = f"!{slen}s{len(packet) - slen}s" (topic, packet) = struct.unpack(pack_format, packet) @@ -3334,7 +3244,7 @@ def _handle_publish(self): # representation of the topic for logging. When the user attempts to # access message.topic in the callback, an exception will be raised. try: - print_topic = topic.decode('utf-8') + print_topic = topic.decode("utf-8") except UnicodeDecodeError: print_topic = f"TOPIC WITH INVALID UTF-8: {topic!r}" @@ -3355,15 +3265,24 @@ def _handle_publish(self): self._easy_log( MQTT_LOG_DEBUG, "Received PUBLISH (d%d, q%d, r%d, m%d), '%s', properties=%s, ... (%d bytes)", - message.dup, message.qos, message.retain, message.mid, - print_topic, message.properties, len(message.payload) + message.dup, + message.qos, + message.retain, + message.mid, + print_topic, + message.properties, + len(message.payload), ) else: self._easy_log( MQTT_LOG_DEBUG, "Received PUBLISH (d%d, q%d, r%d, m%d), '%s', ... (%d bytes)", - message.dup, message.qos, message.retain, message.mid, - print_topic, len(message.payload) + message.dup, + message.qos, + message.retain, + message.mid, + print_topic, + len(message.payload), ) message.timestamp = time_func() @@ -3377,7 +3296,6 @@ def _handle_publish(self): else: return self._send_puback(message.mid) elif message.qos == 2: - rc = self._send_pubrec(message.mid) message.state = mqtt_ms_wait_for_pubrel @@ -3390,10 +3308,10 @@ def _handle_publish(self): def ack(self, mid: int, qos: int) -> int: """ - send an acknowledgement for a given message id. (stored in message.mid ) - only useful in QoS=1 and auto_ack=False + send an acknowledgement for a given message id. (stored in message.mid ) + only useful in QoS=1 and auto_ack=False """ - if self._manual_ack : + if self._manual_ack: if qos == 1: return self._send_puback(mid) elif qos == 2: @@ -3403,21 +3321,20 @@ def ack(self, mid: int, qos: int) -> int: def manual_ack_set(self, on): """ - The paho library normally acknowledges messages as soon as they are delivered to the caller. - If manual_ack is turned on, then the caller MUST manually acknowledge every message once - application processing is complete. + The paho library normally acknowledges messages as soon as they are delivered to the caller. + If manual_ack is turned on, then the caller MUST manually acknowledge every message once + application processing is complete. """ self._manual_ack = on - def _handle_pubrel(self): if self._protocol == MQTTv5: - if self._in_packet['remaining_length'] < 2: + if self._in_packet["remaining_length"] < 2: return MQTT_ERR_PROTOCOL - elif self._in_packet['remaining_length'] != 2: + elif self._in_packet["remaining_length"] != 2: return MQTT_ERR_PROTOCOL - mid, = struct.unpack("!H", self._in_packet['packet']) + (mid,) = struct.unpack("!H", self._in_packet["packet"]) self._easy_log(MQTT_LOG_DEBUG, "Received PUBREL (Mid: %d)", mid) with self._in_message_mutex: @@ -3456,7 +3373,7 @@ def _update_inflight(self): m.state = mqtt_ms_wait_for_pubrec rc = self._send_publish( m.mid, - m.topic.encode('utf-8'), + m.topic.encode("utf-8"), m.payload, m.qos, m.retain, @@ -3471,20 +3388,19 @@ def _update_inflight(self): def _handle_pubrec(self): if self._protocol == MQTTv5: - if self._in_packet['remaining_length'] < 2: + if self._in_packet["remaining_length"] < 2: return MQTT_ERR_PROTOCOL - elif self._in_packet['remaining_length'] != 2: + elif self._in_packet["remaining_length"] != 2: return MQTT_ERR_PROTOCOL - mid, = struct.unpack("!H", self._in_packet['packet'][:2]) + (mid,) = struct.unpack("!H", self._in_packet["packet"][:2]) if self._protocol == MQTTv5: - if self._in_packet['remaining_length'] > 2: + if self._in_packet["remaining_length"] > 2: reasonCode = ReasonCodes(PUBREC >> 4) - reasonCode.unpack(self._in_packet['packet'][2:]) - if self._in_packet['remaining_length'] > 3: + reasonCode.unpack(self._in_packet["packet"][2:]) + if self._in_packet["remaining_length"] > 3: properties = Properties(PUBREC >> 4) - props, props_len = properties.unpack( - self._in_packet['packet'][3:]) + props, props_len = properties.unpack(self._in_packet["packet"][3:]) self._easy_log(MQTT_LOG_DEBUG, "Received PUBREC (Mid: %d)", mid) with self._out_message_mutex: @@ -3498,14 +3414,14 @@ def _handle_pubrec(self): def _handle_unsuback(self): if self._protocol == MQTTv5: - if self._in_packet['remaining_length'] < 4: + if self._in_packet["remaining_length"] < 4: return MQTT_ERR_PROTOCOL - elif self._in_packet['remaining_length'] != 2: + elif self._in_packet["remaining_length"] != 2: return MQTT_ERR_PROTOCOL - mid, = struct.unpack("!H", self._in_packet['packet'][:2]) + (mid,) = struct.unpack("!H", self._in_packet["packet"][:2]) if self._protocol == MQTTv5: - packet = self._in_packet['packet'][2:] + packet = self._in_packet["packet"][2:] properties = Properties(UNSUBACK >> 4) props, props_len = properties.unpack(packet) reasoncodes = [] @@ -3522,13 +3438,11 @@ def _handle_unsuback(self): with self._in_callback_mutex: try: if self._protocol == MQTTv5: - on_unsubscribe( - self, self._userdata, mid, properties, reasoncodes) + on_unsubscribe(self, self._userdata, mid, properties, reasoncodes) else: on_unsubscribe(self, self._userdata, mid) except Exception as err: - self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_unsubscribe: %s', err) + self._easy_log(MQTT_LOG_ERR, "Caught exception in on_unsubscribe: %s", err) if not self.suppress_exceptions: raise @@ -3542,13 +3456,11 @@ def _do_on_disconnect(self, rc, properties=None): with self._in_callback_mutex: try: if self._protocol == MQTTv5: - on_disconnect( - self, self._userdata, rc, properties) + on_disconnect(self, self._userdata, rc, properties) else: on_disconnect(self, self._userdata, rc) except Exception as err: - self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_disconnect: %s', err) + self._easy_log(MQTT_LOG_ERR, "Caught exception in on_disconnect: %s", err) if not self.suppress_exceptions: raise @@ -3561,8 +3473,7 @@ def _do_on_publish(self, mid): try: on_publish(self, self._userdata, mid) except Exception as err: - self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_publish: %s', err) + self._easy_log(MQTT_LOG_ERR, "Caught exception in on_publish: %s", err) if not self.suppress_exceptions: raise @@ -3578,22 +3489,21 @@ def _do_on_publish(self, mid): def _handle_pubackcomp(self, cmd): if self._protocol == MQTTv5: - if self._in_packet['remaining_length'] < 2: + if self._in_packet["remaining_length"] < 2: return MQTT_ERR_PROTOCOL - elif self._in_packet['remaining_length'] != 2: + elif self._in_packet["remaining_length"] != 2: return MQTT_ERR_PROTOCOL packet_type = PUBACK if cmd == "PUBACK" else PUBCOMP packet_type = packet_type >> 4 - mid, = struct.unpack("!H", self._in_packet['packet'][:2]) + (mid,) = struct.unpack("!H", self._in_packet["packet"][:2]) if self._protocol == MQTTv5: - if self._in_packet['remaining_length'] > 2: + if self._in_packet["remaining_length"] > 2: reasonCode = ReasonCodes(packet_type) - reasonCode.unpack(self._in_packet['packet'][2:]) - if self._in_packet['remaining_length'] > 3: + reasonCode.unpack(self._in_packet["packet"][2:]) + if self._in_packet["remaining_length"] > 3: properties = Properties(packet_type) - props, props_len = properties.unpack( - self._in_packet['packet'][3:]) + props, props_len = properties.unpack(self._in_packet["packet"][3:]) self._easy_log(MQTT_LOG_DEBUG, "Received %s (Mid: %d)", cmd, mid) with self._out_message_mutex: @@ -3605,7 +3515,6 @@ def _handle_pubackcomp(self, cmd): return MQTT_ERR_SUCCESS def _handle_on_message(self, message): - try: topic = message.topic except UnicodeDecodeError: @@ -3627,12 +3536,7 @@ def _handle_on_message(self, message): try: callback(self, self._userdata, message) except Exception as err: - self._easy_log( - MQTT_LOG_ERR, - 'Caught exception in user defined callback function %s: %s', - callback.__name__, - err - ) + self._easy_log(MQTT_LOG_ERR, "Caught exception in user defined callback function %s: %s", callback.__name__, err) if not self.suppress_exceptions: raise @@ -3641,12 +3545,10 @@ def _handle_on_message(self, message): try: on_message(self, self._userdata, message) except Exception as err: - self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_message: %s', err) + self._easy_log(MQTT_LOG_ERR, "Caught exception in on_message: %s", err) if not self.suppress_exceptions: raise - def _handle_on_connect_fail(self): with self._callback_mutex: on_connect_fail = self.on_connect_fail @@ -3656,8 +3558,7 @@ def _handle_on_connect_fail(self): try: on_connect_fail(self, self._userdata) except Exception as err: - self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_connect_fail: %s', err) + self._easy_log(MQTT_LOG_ERR, "Caught exception in on_connect_fail: %s", err) def _thread_main(self): self.loop_forever(retry_first_connection=True) @@ -3677,18 +3578,14 @@ def _reconnect_wait(self): target_time = now + self._reconnect_delay remaining = target_time - now - while (self._state != mqtt_cs_disconnecting - and not self._thread_terminate - and remaining > 0): - + while self._state != mqtt_cs_disconnecting and not self._thread_terminate and remaining > 0: time.sleep(min(remaining, 1)) remaining = target_time - time_func() @staticmethod def _proxy_is_valid(p): def check(t, a): - return (socks is not None and - t in set([socks.HTTP, socks.SOCKS4, socks.SOCKS5]) and a) + return socks is not None and t in set([socks.HTTP, socks.SOCKS4, socks.SOCKS5]) and a if isinstance(p, dict): return check(p.get("proxy_type"), p.get("proxy_addr")) @@ -3708,32 +3605,22 @@ def _get_proxy(self): # Next, check for an mqtt_proxy environment variable as long as the host # we're trying to connect to isn't listed under the no_proxy environment # variable (matches built-in module urllib's behavior) - if not (hasattr(urllib.request, "proxy_bypass") and - urllib.request.proxy_bypass(self._host)): + if not (hasattr(urllib.request, "proxy_bypass") and urllib.request.proxy_bypass(self._host)): env_proxies = urllib.request.getproxies() if "mqtt" in env_proxies: parts = urllib.parse.urlparse(env_proxies["mqtt"]) if parts.scheme == "http": - proxy = { - "proxy_type": socks.HTTP, - "proxy_addr": parts.hostname, - "proxy_port": parts.port - } + proxy = {"proxy_type": socks.HTTP, "proxy_addr": parts.hostname, "proxy_port": parts.port} return proxy elif parts.scheme == "socks": - proxy = { - "proxy_type": socks.SOCKS5, - "proxy_addr": parts.hostname, - "proxy_port": parts.port - } + proxy = {"proxy_type": socks.SOCKS5, "proxy_addr": parts.hostname, "proxy_port": parts.port} return proxy # Finally, check if the user has monkeypatched the PySocks library with # a default proxy socks_default = socks.get_default_proxy() if self._proxy_is_valid(socks_default): - proxy_keys = ("proxy_type", "proxy_addr", "proxy_port", - "proxy_rdns", "proxy_username", "proxy_password") + proxy_keys = ("proxy_type", "proxy_addr", "proxy_port", "proxy_rdns", "proxy_username", "proxy_password") return dict(zip(proxy_keys, socks_default)) # If we didn't find a proxy through any of the above methods, return @@ -3757,10 +3644,9 @@ class WebsocketWrapper: OPCODE_BINARY = 0x2 OPCODE_CONNCLOSE = 0x8 OPCODE_PING = 0x9 - OPCODE_PONG = 0xa + OPCODE_PONG = 0xA def __init__(self, socket, host, port, is_ssl, path, extra_headers): - self.connected = False self._ssl = is_ssl @@ -3779,12 +3665,10 @@ def __init__(self, socket, host, port, is_ssl, path, extra_headers): self._do_handshake(extra_headers) def __del__(self): - self._sendbuffer = None self._readbuffer = None def _do_handshake(self, extra_headers): - sec_websocket_key = uuid.uuid4().bytes sec_websocket_key = base64.b64encode(sec_websocket_key) @@ -3805,11 +3689,13 @@ def _do_handshake(self, extra_headers): elif callable(extra_headers): websocket_headers = extra_headers(websocket_headers) - header = "\r\n".join([ - f"GET {self._path} HTTP/1.1", - "\r\n".join(f"{i}: {j}" for i, j in websocket_headers.items()), - "\r\n", - ]).encode("utf8") + header = "\r\n".join( + [ + f"GET {self._path} HTTP/1.1", + "\r\n".join(f"{i}: {j}" for i, j in websocket_headers.items()), + "\r\n", + ] + ).encode("utf8") self._socket.send(header) @@ -3826,29 +3712,26 @@ def _do_handshake(self, extra_headers): if byte == b"\n": if len(self._readbuffer) > 2: # check upgrade - if b"connection" in str(self._readbuffer).lower().encode('utf-8'): - if b"upgrade" not in str(self._readbuffer).lower().encode('utf-8'): - raise WebsocketConnectionError( - "WebSocket handshake error, connection not upgraded") + if b"connection" in str(self._readbuffer).lower().encode("utf-8"): + if b"upgrade" not in str(self._readbuffer).lower().encode("utf-8"): + raise WebsocketConnectionError("WebSocket handshake error, connection not upgraded") else: has_upgrade = True # check key hash - if b"sec-websocket-accept" in str(self._readbuffer).lower().encode('utf-8'): + if b"sec-websocket-accept" in str(self._readbuffer).lower().encode("utf-8"): GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" - server_hash = self._readbuffer.decode( - 'utf-8').split(": ", 1)[1] - server_hash = server_hash.strip().encode('utf-8') + server_hash = self._readbuffer.decode("utf-8").split(": ", 1)[1] + server_hash = server_hash.strip().encode("utf-8") - client_hash = sec_websocket_key.decode('utf-8') + GUID + client_hash = sec_websocket_key.decode("utf-8") + GUID # Use of SHA-1 is OK here; it's according to the Websocket spec. - client_hash = hashlib.sha1(client_hash.encode('utf-8')) # noqa: S324 + client_hash = hashlib.sha1(client_hash.encode("utf-8")) # noqa: S324 client_hash = base64.b64encode(client_hash.digest()) if server_hash != client_hash: - raise WebsocketConnectionError( - "WebSocket handshake error, invalid secret key") + raise WebsocketConnectionError("WebSocket handshake error, invalid secret key") else: has_secret = True else: @@ -3869,7 +3752,6 @@ def _do_handshake(self, extra_headers): self.connected = True def _create_frame(self, opcode, data, do_masking=1): - header = bytearray() length = len(data) @@ -3901,11 +3783,9 @@ def _create_frame(self, opcode, data, do_masking=1): return header + data def _buffered_read(self, length): - # try to recv and store needed bytes wanted_bytes = length - (len(self._readbuffer) - self._readbuffer_head) if wanted_bytes > 0: - data = self._socket.recv(wanted_bytes) if not data: @@ -3917,13 +3797,11 @@ def _buffered_read(self, length): raise BlockingIOError self._readbuffer_head += length - return self._readbuffer[self._readbuffer_head - length:self._readbuffer_head] + return self._readbuffer[self._readbuffer_head - length : self._readbuffer_head] def _recv_impl(self, length): - # try to decode websocket payload part from data try: - self._readbuffer_head = 0 result = None @@ -3934,22 +3812,20 @@ def _recv_impl(self, length): header1 = self._buffered_read(1) header2 = self._buffered_read(1) - opcode = (header1[0] & 0x0f) + opcode = header1[0] & 0x0F maskbit = (header2[0] & 0x80) == 0x80 - lengthbits = (header2[0] & 0x7f) + lengthbits = header2[0] & 0x7F payload_length = lengthbits mask_key = None # read length - if lengthbits == 0x7e: - + if lengthbits == 0x7E: value = self._buffered_read(2) - payload_length, = struct.unpack("!H", value) - - elif lengthbits == 0x7f: + (payload_length,) = struct.unpack("!H", value) + elif lengthbits == 0x7F: value = self._buffered_read(8) - payload_length, = struct.unpack("!Q", value) + (payload_length,) = struct.unpack("!Q", value) # read mask if maskbit: @@ -3981,34 +3857,29 @@ def _recv_impl(self, length): # respond to non-binary opcodes, their arrival is not guaranteed because of non-blocking sockets if opcode == WebsocketWrapper.OPCODE_CONNCLOSE: - frame = self._create_frame( - WebsocketWrapper.OPCODE_CONNCLOSE, payload, 0) + frame = self._create_frame(WebsocketWrapper.OPCODE_CONNCLOSE, payload, 0) self._socket.send(frame) if opcode == WebsocketWrapper.OPCODE_PING: - frame = self._create_frame( - WebsocketWrapper.OPCODE_PONG, payload, 0) + frame = self._create_frame(WebsocketWrapper.OPCODE_PONG, payload, 0) self._socket.send(frame) # This isn't *proper* handling of continuation frames, but given # that we only support binary frames, it is *probably* good enough. - if (opcode == WebsocketWrapper.OPCODE_BINARY or opcode == WebsocketWrapper.OPCODE_CONTINUATION) \ - and payload_length > 0: + if (opcode == WebsocketWrapper.OPCODE_BINARY or opcode == WebsocketWrapper.OPCODE_CONTINUATION) and payload_length > 0: return result else: raise BlockingIOError except ConnectionError: self.connected = False - return b'' + return b"" def _send_impl(self, data): - # if previous frame was sent successfully if len(self._sendbuffer) == 0: # create websocket frame - frame = self._create_frame( - WebsocketWrapper.OPCODE_BINARY, bytearray(data)) + frame = self._create_frame(WebsocketWrapper.OPCODE_BINARY, bytearray(data)) self._sendbuffer.extend(frame) self._requested_size = len(data) diff --git a/src/paho/mqtt/matcher.py b/src/paho/mqtt/matcher.py index b73c13ac..37940a30 100644 --- a/src/paho/mqtt/matcher.py +++ b/src/paho/mqtt/matcher.py @@ -7,7 +7,7 @@ class MQTTMatcher: some topic name.""" class Node: - __slots__ = '_children', '_content' + __slots__ = "_children", "_content" def __init__(self): self._children = {} @@ -20,7 +20,7 @@ def __setitem__(self, key, value): """Add a topic filter :key to the prefix tree and associate it to :value""" node = self._root - for sym in key.split('/'): + for sym in key.split("/"): node = node._children.setdefault(sym, self.Node()) node._content = value @@ -28,7 +28,7 @@ def __getitem__(self, key): """Retrieve the value associated with some topic filter :key""" try: node = self._root - for sym in key.split('/'): + for sym in key.split("/"): node = node._children[sym] if node._content is None: raise KeyError(key) @@ -41,9 +41,9 @@ def __delitem__(self, key): lst = [] try: parent, node = None, self._root - for k in key.split('/'): - parent, node = node, node._children[k] - lst.append((parent, k, node)) + for k in key.split("/"): + parent, node = node, node._children[k] + lst.append((parent, k, node)) # TODO node._content = None except KeyError as ke: @@ -51,14 +51,15 @@ def __delitem__(self, key): else: # cleanup for parent, k, node in reversed(lst): if node._children or node._content is not None: - break + break del parent._children[k] def iter_match(self, topic): """Return an iterator on all values associated with filters that match the :topic""" - lst = topic.split('/') - normal = not topic.startswith('$') + lst = topic.split("/") + normal = not topic.startswith("$") + def rec(node, i=0): if i == len(lst): if node._content is not None: @@ -68,11 +69,12 @@ def rec(node, i=0): if part in node._children: for content in rec(node._children[part], i + 1): yield content - if '+' in node._children and (normal or i > 0): - for content in rec(node._children['+'], i + 1): + if "+" in node._children and (normal or i > 0): + for content in rec(node._children["+"], i + 1): yield content - if '#' in node._children and (normal or i > 0): - content = node._children['#']._content + if "#" in node._children and (normal or i > 0): + content = node._children["#"]._content if content is not None: yield content + return rec(self._root) diff --git a/src/paho/mqtt/packettypes.py b/src/paho/mqtt/packettypes.py index 2fd6a1b5..7ce4a3f0 100644 --- a/src/paho/mqtt/packettypes.py +++ b/src/paho/mqtt/packettypes.py @@ -30,14 +30,26 @@ class PacketTypes: indexes = range(1, 16) # Packet types - CONNECT, CONNACK, PUBLISH, PUBACK, PUBREC, PUBREL, \ - PUBCOMP, SUBSCRIBE, SUBACK, UNSUBSCRIBE, UNSUBACK, \ - PINGREQ, PINGRESP, DISCONNECT, AUTH = indexes + CONNECT, CONNACK, PUBLISH, PUBACK, PUBREC, PUBREL, PUBCOMP, SUBSCRIBE, SUBACK, UNSUBSCRIBE, UNSUBACK, PINGREQ, PINGRESP, DISCONNECT, AUTH = indexes # Dummy packet type for properties use - will delay only applies to will WILLMESSAGE = 99 - Names = [ "reserved", \ - "Connect", "Connack", "Publish", "Puback", "Pubrec", "Pubrel", \ - "Pubcomp", "Subscribe", "Suback", "Unsubscribe", "Unsuback", \ - "Pingreq", "Pingresp", "Disconnect", "Auth"] + Names = [ + "reserved", + "Connect", + "Connack", + "Publish", + "Puback", + "Pubrec", + "Pubrel", + "Pubcomp", + "Subscribe", + "Suback", + "Unsubscribe", + "Unsuback", + "Pingreq", + "Pingresp", + "Disconnect", + "Auth", + ] diff --git a/src/paho/mqtt/properties.py b/src/paho/mqtt/properties.py index e5e19103..eddca0fe 100644 --- a/src/paho/mqtt/properties.py +++ b/src/paho/mqtt/properties.py @@ -64,17 +64,17 @@ def readUTF(buffer, maxlen): maxlen -= 2 if length > maxlen: raise MalformedPacket("Length delimited string too long") - buf = buffer[2:2+length].decode("utf-8") + buf = buffer[2 : 2 + length].decode("utf-8") # look for chars which are invalid for MQTT - for c in buf: # look for D800-DFFF in the UTF string + for c in buf: # look for D800-DFFF in the UTF string ord_c = ord(c) if ord_c >= 0xD800 and ord_c <= 0xDFFF: raise MalformedPacket("[MQTT-1.5.4-1] D800-DFFF found in UTF-8 data") - if ord_c == 0x00: # look for null in the UTF string + if ord_c == 0x00: # look for null in the UTF string raise MalformedPacket("[MQTT-1.5.4-2] Null found in UTF-8 data") if ord_c == 0xFEFF: raise MalformedPacket("[MQTT-1.5.4-3] U+FEFF in UTF-8 data") - return buf, length+2 + return buf, length + 2 def writeBytes(buffer): @@ -83,7 +83,7 @@ def writeBytes(buffer): def readBytes(buffer): length = readInt16(buffer) - return buffer[2:2+length], length+2 + return buffer[2 : 2 + length], length + 2 class VariableByteIntegers: # Variable Byte Integer @@ -96,12 +96,12 @@ class VariableByteIntegers: # Variable Byte Integer @staticmethod def encode(x): """ - Convert an integer 0 <= x <= 268435455 into multi-byte format. - Returns the buffer convered from the integer. + Convert an integer 0 <= x <= 268435455 into multi-byte format. + Returns the buffer convered from the integer. """ if not 0 <= x <= 268435455: raise ValueError(f"Value {x!r} must be in range 0-268435455") - buffer = b'' + buffer = b"" while 1: digit = x % 128 x //= 128 @@ -115,10 +115,10 @@ def encode(x): @staticmethod def decode(buffer): """ - Get the value of a multi-byte integer from a buffer - Return the value, and the number of bytes used. + Get the value of a multi-byte integer from a buffer + Return the value, and the number of bytes used. - [MQTT-1.5.5-1] the encoded value MUST use the minimum number of bytes necessary to represent the value + [MQTT-1.5.5-1] the encoded value MUST use the minimum number of bytes necessary to represent the value """ multiplier = 1 value = 0 @@ -155,8 +155,7 @@ class Properties: def __init__(self, packetType): self.packetType = packetType - self.types = ["Byte", "Two Byte Integer", "Four Byte Integer", "Variable Byte Integer", - "Binary Data", "UTF-8 Encoded String", "UTF-8 String Pair"] + self.types = ["Byte", "Two Byte Integer", "Four Byte Integer", "Variable Byte Integer", "Binary Data", "UTF-8 Encoded String", "UTF-8 String Pair"] self.names = { "Payload Format Indicator": 1, @@ -185,7 +184,7 @@ def __init__(self, packetType): "Maximum Packet Size": 39, "Wildcard Subscription Available": 40, "Subscription Identifier Available": 41, - "Shared Subscription Available": 42 + "Shared Subscription Available": 42, } self.properties = { @@ -196,43 +195,56 @@ def __init__(self, packetType): 3: (self.types.index("UTF-8 Encoded String"), [PacketTypes.PUBLISH, PacketTypes.WILLMESSAGE]), 8: (self.types.index("UTF-8 Encoded String"), [PacketTypes.PUBLISH, PacketTypes.WILLMESSAGE]), 9: (self.types.index("Binary Data"), [PacketTypes.PUBLISH, PacketTypes.WILLMESSAGE]), - 11: (self.types.index("Variable Byte Integer"), - [PacketTypes.PUBLISH, PacketTypes.SUBSCRIBE]), - 17: (self.types.index("Four Byte Integer"), - [PacketTypes.CONNECT, PacketTypes.CONNACK, PacketTypes.DISCONNECT]), + 11: (self.types.index("Variable Byte Integer"), [PacketTypes.PUBLISH, PacketTypes.SUBSCRIBE]), + 17: (self.types.index("Four Byte Integer"), [PacketTypes.CONNECT, PacketTypes.CONNACK, PacketTypes.DISCONNECT]), 18: (self.types.index("UTF-8 Encoded String"), [PacketTypes.CONNACK]), 19: (self.types.index("Two Byte Integer"), [PacketTypes.CONNACK]), - 21: (self.types.index("UTF-8 Encoded String"), - [PacketTypes.CONNECT, PacketTypes.CONNACK, PacketTypes.AUTH]), - 22: (self.types.index("Binary Data"), - [PacketTypes.CONNECT, PacketTypes.CONNACK, PacketTypes.AUTH]), - 23: (self.types.index("Byte"), - [PacketTypes.CONNECT]), + 21: (self.types.index("UTF-8 Encoded String"), [PacketTypes.CONNECT, PacketTypes.CONNACK, PacketTypes.AUTH]), + 22: (self.types.index("Binary Data"), [PacketTypes.CONNECT, PacketTypes.CONNACK, PacketTypes.AUTH]), + 23: (self.types.index("Byte"), [PacketTypes.CONNECT]), 24: (self.types.index("Four Byte Integer"), [PacketTypes.WILLMESSAGE]), 25: (self.types.index("Byte"), [PacketTypes.CONNECT]), 26: (self.types.index("UTF-8 Encoded String"), [PacketTypes.CONNACK]), - 28: (self.types.index("UTF-8 Encoded String"), - [PacketTypes.CONNACK, PacketTypes.DISCONNECT]), - 31: (self.types.index("UTF-8 Encoded String"), - [PacketTypes.CONNACK, PacketTypes.PUBACK, PacketTypes.PUBREC, - PacketTypes.PUBREL, PacketTypes.PUBCOMP, PacketTypes.SUBACK, - PacketTypes.UNSUBACK, PacketTypes.DISCONNECT, PacketTypes.AUTH]), - 33: (self.types.index("Two Byte Integer"), - [PacketTypes.CONNECT, PacketTypes.CONNACK]), - 34: (self.types.index("Two Byte Integer"), - [PacketTypes.CONNECT, PacketTypes.CONNACK]), + 28: (self.types.index("UTF-8 Encoded String"), [PacketTypes.CONNACK, PacketTypes.DISCONNECT]), + 31: ( + self.types.index("UTF-8 Encoded String"), + [ + PacketTypes.CONNACK, + PacketTypes.PUBACK, + PacketTypes.PUBREC, + PacketTypes.PUBREL, + PacketTypes.PUBCOMP, + PacketTypes.SUBACK, + PacketTypes.UNSUBACK, + PacketTypes.DISCONNECT, + PacketTypes.AUTH, + ], + ), + 33: (self.types.index("Two Byte Integer"), [PacketTypes.CONNECT, PacketTypes.CONNACK]), + 34: (self.types.index("Two Byte Integer"), [PacketTypes.CONNECT, PacketTypes.CONNACK]), 35: (self.types.index("Two Byte Integer"), [PacketTypes.PUBLISH]), 36: (self.types.index("Byte"), [PacketTypes.CONNACK]), 37: (self.types.index("Byte"), [PacketTypes.CONNACK]), - 38: (self.types.index("UTF-8 String Pair"), - [PacketTypes.CONNECT, PacketTypes.CONNACK, - PacketTypes.PUBLISH, PacketTypes.PUBACK, - PacketTypes.PUBREC, PacketTypes.PUBREL, PacketTypes.PUBCOMP, - PacketTypes.SUBSCRIBE, PacketTypes.SUBACK, - PacketTypes.UNSUBSCRIBE, PacketTypes.UNSUBACK, - PacketTypes.DISCONNECT, PacketTypes.AUTH, PacketTypes.WILLMESSAGE]), - 39: (self.types.index("Four Byte Integer"), - [PacketTypes.CONNECT, PacketTypes.CONNACK]), + 38: ( + self.types.index("UTF-8 String Pair"), + [ + PacketTypes.CONNECT, + PacketTypes.CONNACK, + PacketTypes.PUBLISH, + PacketTypes.PUBACK, + PacketTypes.PUBREC, + PacketTypes.PUBREL, + PacketTypes.PUBCOMP, + PacketTypes.SUBSCRIBE, + PacketTypes.SUBACK, + PacketTypes.UNSUBSCRIBE, + PacketTypes.UNSUBACK, + PacketTypes.DISCONNECT, + PacketTypes.AUTH, + PacketTypes.WILLMESSAGE, + ], + ), + 39: (self.types.index("Four Byte Integer"), [PacketTypes.CONNECT, PacketTypes.CONNACK]), 40: (self.types.index("Byte"), [PacketTypes.CONNACK]), 41: (self.types.index("Byte"), [PacketTypes.CONNACK]), 42: (self.types.index("Byte"), [PacketTypes.CONNACK]), @@ -245,44 +257,34 @@ def getIdentFromName(self, compressedName): # return the identifier corresponding to the property name result = -1 for name in self.names.keys(): - if compressedName == name.replace(' ', ''): + if compressedName == name.replace(" ", ""): result = self.names[name] break return result def __setattr__(self, name, value): - name = name.replace(' ', '') + name = name.replace(" ", "") privateVars = ["packetType", "types", "names", "properties"] if name in privateVars: object.__setattr__(self, name, value) else: # the name could have spaces in, or not. Remove spaces before assignment - if name not in [aname.replace(' ', '') for aname in self.names.keys()]: - raise MQTTException( - f"Property name must be one of {self.names.keys()}") + if name not in [aname.replace(" ", "") for aname in self.names.keys()]: + raise MQTTException(f"Property name must be one of {self.names.keys()}") # check that this attribute applies to the packet type if self.packetType not in self.properties[self.getIdentFromName(name)][1]: raise MQTTException(f"Property {name} does not apply to packet type {PacketTypes.Names[self.packetType]}") # Check for forbidden values if not isinstance(value, list): - if name in ["ReceiveMaximum", "TopicAlias"] \ - and (value < 1 or value > 65535): - + if name in ["ReceiveMaximum", "TopicAlias"] and (value < 1 or value > 65535): raise MQTTException(f"{name} property value must be in the range 1-65535") - elif name in ["TopicAliasMaximum"] \ - and (value < 0 or value > 65535): - + elif name in ["TopicAliasMaximum"] and (value < 0 or value > 65535): raise MQTTException(f"{name} property value must be in the range 0-65535") - elif name in ["MaximumPacketSize", "SubscriptionIdentifier"] \ - and (value < 1 or value > 268435455): - + elif name in ["MaximumPacketSize", "SubscriptionIdentifier"] and (value < 1 or value > 268435455): raise MQTTException(f"{name} property value must be in the range 1-268435455") - elif name in ["RequestResponseInformation", "RequestProblemInformation", "PayloadFormatIndicator"] \ - and (value != 0 and value != 1): - - raise MQTTException( - f"{name} property value must be 0 or 1") + elif name in ["RequestResponseInformation", "RequestProblemInformation", "PayloadFormatIndicator"] and (value != 0 and value != 1): + raise MQTTException(f"{name} property value must be 0 or 1") if self.allowsMultiple(name): if not isinstance(value, list): @@ -295,7 +297,7 @@ def __str__(self): buffer = "[" first = True for name in self.names.keys(): - compressedName = name.replace(' ', '') + compressedName = name.replace(" ", "") if hasattr(self, compressedName): if not first: buffer += ", " @@ -307,10 +309,10 @@ def __str__(self): def json(self): data = {} for name in self.names.keys(): - compressedName = name.replace(' ', '') + compressedName = name.replace(" ", "") if hasattr(self, compressedName): val = getattr(self, compressedName) - if compressedName == 'CorrelationData' and isinstance(val, bytes): + if compressedName == "CorrelationData" and isinstance(val, bytes): data[compressedName] = val.hex() else: data[compressedName] = val @@ -319,7 +321,7 @@ def json(self): def isEmpty(self): rc = True for name in self.names.keys(): - compressedName = name.replace(' ', '') + compressedName = name.replace(" ", "") if hasattr(self, compressedName): rc = False break @@ -327,7 +329,7 @@ def isEmpty(self): def clear(self): for name in self.names.keys(): - compressedName = name.replace(' ', '') + compressedName = name.replace(" ", "") if hasattr(self, compressedName): delattr(self, compressedName) @@ -354,17 +356,15 @@ def pack(self): # serialize properties into buffer for sending over network buffer = b"" for name in self.names.keys(): - compressedName = name.replace(' ', '') + compressedName = name.replace(" ", "") if hasattr(self, compressedName): identifier = self.getIdentFromName(compressedName) attr_type = self.properties[identifier][0] if self.allowsMultiple(compressedName): for prop in getattr(self, compressedName): - buffer += self.writeProperty(identifier, - attr_type, prop) + buffer += self.writeProperty(identifier, attr_type, prop) else: - buffer += self.writeProperty(identifier, attr_type, - getattr(self, compressedName)) + buffer += self.writeProperty(identifier, attr_type, getattr(self, compressedName)) return VariableByteIntegers.encode(len(buffer)) + buffer def readProperty(self, buffer, type, propslen): @@ -405,19 +405,16 @@ def unpack(self, buffer): buffer = buffer[VBIlen:] # strip the bytes used by the VBI propslenleft = propslen while propslenleft > 0: # properties length is 0 if there are none - identifier, VBIlen2 = VariableByteIntegers.decode( - buffer) # property identifier + identifier, VBIlen2 = VariableByteIntegers.decode(buffer) # property identifier buffer = buffer[VBIlen2:] # strip the bytes used by the VBI propslenleft -= VBIlen2 attr_type = self.properties[identifier][0] - value, valuelen = self.readProperty( - buffer, attr_type, propslenleft) + value, valuelen = self.readProperty(buffer, attr_type, propslenleft) buffer = buffer[valuelen:] # strip the bytes used by the value propslenleft -= valuelen propname = self.getNameFromIdent(identifier) - compressedName = propname.replace(' ', '') + compressedName = propname.replace(" ", "") if not self.allowsMultiple(compressedName) and hasattr(self, compressedName): - raise MQTTException( - f"Property '{property}' must not exist more than once") + raise MQTTException(f"Property '{property}' must not exist more than once") setattr(self, propname, value) return self, propslen + VBIlen diff --git a/src/paho/mqtt/publish.py b/src/paho/mqtt/publish.py index 38138585..054bc24a 100644 --- a/src/paho/mqtt/publish.py +++ b/src/paho/mqtt/publish.py @@ -36,12 +36,12 @@ def _do_publish(client): elif isinstance(message, (tuple, list)): client.publish(*message) else: - raise TypeError('message must be a dict, tuple, or list') + raise TypeError("message must be a dict, tuple, or list") def _on_connect(client, userdata, flags, rc): """Internal callback""" - #pylint: disable=invalid-name, unused-argument + # pylint: disable=invalid-name, unused-argument if rc == 0: if len(userdata) > 0: @@ -49,13 +49,15 @@ def _on_connect(client, userdata, flags, rc): else: raise mqtt.MQTTException(paho.connack_string(rc)) + def _on_connect_v5(client, userdata, flags, rc, properties): """Internal v5 callback""" _on_connect(client, userdata, flags, rc) + def _on_publish(client, userdata, mid): """Internal callback""" - #pylint: disable=unused-argument + # pylint: disable=unused-argument if len(userdata) == 0: client.disconnect() @@ -63,9 +65,9 @@ def _on_publish(client, userdata, mid): _do_publish(client) -def multiple(msgs, hostname="localhost", port=1883, client_id="", keepalive=60, - will=None, auth=None, tls=None, protocol=paho.MQTTv311, - transport="tcp", proxy_args=None): +def multiple( + msgs, hostname="localhost", port=1883, client_id="", keepalive=60, will=None, auth=None, tls=None, protocol=paho.MQTTv311, transport="tcp", proxy_args=None +): """Publish multiple messages to a broker, then disconnect cleanly. This function creates an MQTT client, connects to a broker and publishes a @@ -129,11 +131,9 @@ def multiple(msgs, hostname="localhost", port=1883, client_id="", keepalive=60, """ if not isinstance(msgs, Iterable): - raise TypeError('msgs must be an iterable') - + raise TypeError("msgs must be an iterable") - client = paho.Client(client_id=client_id, userdata=collections.deque(msgs), - protocol=protocol, transport=transport) + client = paho.Client(client_id=client_id, userdata=collections.deque(msgs), protocol=protocol, transport=transport) client.on_publish = _on_publish if protocol == mqtt.client.MQTTv5: @@ -145,20 +145,19 @@ def multiple(msgs, hostname="localhost", port=1883, client_id="", keepalive=60, client.proxy_set(**proxy_args) if auth: - username = auth.get('username') + username = auth.get("username") if username: - password = auth.get('password') + password = auth.get("password") client.username_pw_set(username, password) else: - raise KeyError("The 'username' key was not found, this is " - "required for auth") + raise KeyError("The 'username' key was not found, this is " "required for auth") if will is not None: client.will_set(**will) if tls is not None: if isinstance(tls, dict): - insecure = tls.pop('insecure', False) + insecure = tls.pop("insecure", False) client.tls_set(**tls) if insecure: # Must be set *after* the `client.tls_set()` call since it sets @@ -172,9 +171,22 @@ def multiple(msgs, hostname="localhost", port=1883, client_id="", keepalive=60, client.loop_forever() -def single(topic, payload=None, qos=0, retain=False, hostname="localhost", - port=1883, client_id="", keepalive=60, will=None, auth=None, - tls=None, protocol=paho.MQTTv311, transport="tcp", proxy_args=None): +def single( + topic, + payload=None, + qos=0, + retain=False, + hostname="localhost", + port=1883, + client_id="", + keepalive=60, + will=None, + auth=None, + tls=None, + protocol=paho.MQTTv311, + transport="tcp", + proxy_args=None, +): """Publish a single message to a broker, then disconnect cleanly. This function creates an MQTT client, connects to a broker and publishes a @@ -230,7 +242,6 @@ def single(topic, payload=None, qos=0, retain=False, hostname="localhost", proxy_args: a dictionary that will be given to the client. """ - msg = {'topic':topic, 'payload':payload, 'qos':qos, 'retain':retain} + msg = {"topic": topic, "payload": payload, "qos": qos, "retain": retain} - multiple([msg], hostname, port, client_id, keepalive, will, auth, tls, - protocol, transport, proxy_args) + multiple([msg], hostname, port, client_id, keepalive, will, auth, tls, protocol, transport, proxy_args) diff --git a/src/paho/mqtt/reasoncodes.py b/src/paho/mqtt/reasoncodes.py index 69a313f7..107083bf 100644 --- a/src/paho/mqtt/reasoncodes.py +++ b/src/paho/mqtt/reasoncodes.py @@ -43,80 +43,76 @@ def __init__(self, packetType, aName="Success", identifier=-1): self.packetType = packetType self.names = { - 0: {"Success": [PacketTypes.CONNACK, PacketTypes.PUBACK, - PacketTypes.PUBREC, PacketTypes.PUBREL, PacketTypes.PUBCOMP, - PacketTypes.UNSUBACK, PacketTypes.AUTH], + 0: { + "Success": [ + PacketTypes.CONNACK, + PacketTypes.PUBACK, + PacketTypes.PUBREC, + PacketTypes.PUBREL, + PacketTypes.PUBCOMP, + PacketTypes.UNSUBACK, + PacketTypes.AUTH, + ], "Normal disconnection": [PacketTypes.DISCONNECT], - "Granted QoS 0": [PacketTypes.SUBACK]}, + "Granted QoS 0": [PacketTypes.SUBACK], + }, 1: {"Granted QoS 1": [PacketTypes.SUBACK]}, 2: {"Granted QoS 2": [PacketTypes.SUBACK]}, 4: {"Disconnect with will message": [PacketTypes.DISCONNECT]}, - 16: {"No matching subscribers": - [PacketTypes.PUBACK, PacketTypes.PUBREC]}, + 16: {"No matching subscribers": [PacketTypes.PUBACK, PacketTypes.PUBREC]}, 17: {"No subscription found": [PacketTypes.UNSUBACK]}, 24: {"Continue authentication": [PacketTypes.AUTH]}, 25: {"Re-authenticate": [PacketTypes.AUTH]}, - 128: {"Unspecified error": [PacketTypes.CONNACK, PacketTypes.PUBACK, - PacketTypes.PUBREC, PacketTypes.SUBACK, PacketTypes.UNSUBACK, - PacketTypes.DISCONNECT], }, - 129: {"Malformed packet": - [PacketTypes.CONNACK, PacketTypes.DISCONNECT]}, - 130: {"Protocol error": - [PacketTypes.CONNACK, PacketTypes.DISCONNECT]}, - 131: {"Implementation specific error": [PacketTypes.CONNACK, - PacketTypes.PUBACK, PacketTypes.PUBREC, PacketTypes.SUBACK, - PacketTypes.UNSUBACK, PacketTypes.DISCONNECT], }, + 128: { + "Unspecified error": [PacketTypes.CONNACK, PacketTypes.PUBACK, PacketTypes.PUBREC, PacketTypes.SUBACK, PacketTypes.UNSUBACK, PacketTypes.DISCONNECT], + }, + 129: {"Malformed packet": [PacketTypes.CONNACK, PacketTypes.DISCONNECT]}, + 130: {"Protocol error": [PacketTypes.CONNACK, PacketTypes.DISCONNECT]}, + 131: { + "Implementation specific error": [ + PacketTypes.CONNACK, + PacketTypes.PUBACK, + PacketTypes.PUBREC, + PacketTypes.SUBACK, + PacketTypes.UNSUBACK, + PacketTypes.DISCONNECT, + ], + }, 132: {"Unsupported protocol version": [PacketTypes.CONNACK]}, 133: {"Client identifier not valid": [PacketTypes.CONNACK]}, 134: {"Bad user name or password": [PacketTypes.CONNACK]}, - 135: {"Not authorized": [PacketTypes.CONNACK, PacketTypes.PUBACK, - PacketTypes.PUBREC, PacketTypes.SUBACK, PacketTypes.UNSUBACK, - PacketTypes.DISCONNECT], }, + 135: { + "Not authorized": [PacketTypes.CONNACK, PacketTypes.PUBACK, PacketTypes.PUBREC, PacketTypes.SUBACK, PacketTypes.UNSUBACK, PacketTypes.DISCONNECT], + }, 136: {"Server unavailable": [PacketTypes.CONNACK]}, 137: {"Server busy": [PacketTypes.CONNACK, PacketTypes.DISCONNECT]}, 138: {"Banned": [PacketTypes.CONNACK]}, 139: {"Server shutting down": [PacketTypes.DISCONNECT]}, - 140: {"Bad authentication method": - [PacketTypes.CONNACK, PacketTypes.DISCONNECT]}, + 140: {"Bad authentication method": [PacketTypes.CONNACK, PacketTypes.DISCONNECT]}, 141: {"Keep alive timeout": [PacketTypes.DISCONNECT]}, 142: {"Session taken over": [PacketTypes.DISCONNECT]}, - 143: {"Topic filter invalid": - [PacketTypes.SUBACK, PacketTypes.UNSUBACK, PacketTypes.DISCONNECT]}, - 144: {"Topic name invalid": - [PacketTypes.CONNACK, PacketTypes.PUBACK, - PacketTypes.PUBREC, PacketTypes.DISCONNECT]}, - 145: {"Packet identifier in use": - [PacketTypes.PUBACK, PacketTypes.PUBREC, - PacketTypes.SUBACK, PacketTypes.UNSUBACK]}, - 146: {"Packet identifier not found": - [PacketTypes.PUBREL, PacketTypes.PUBCOMP]}, + 143: {"Topic filter invalid": [PacketTypes.SUBACK, PacketTypes.UNSUBACK, PacketTypes.DISCONNECT]}, + 144: {"Topic name invalid": [PacketTypes.CONNACK, PacketTypes.PUBACK, PacketTypes.PUBREC, PacketTypes.DISCONNECT]}, + 145: {"Packet identifier in use": [PacketTypes.PUBACK, PacketTypes.PUBREC, PacketTypes.SUBACK, PacketTypes.UNSUBACK]}, + 146: {"Packet identifier not found": [PacketTypes.PUBREL, PacketTypes.PUBCOMP]}, 147: {"Receive maximum exceeded": [PacketTypes.DISCONNECT]}, 148: {"Topic alias invalid": [PacketTypes.DISCONNECT]}, 149: {"Packet too large": [PacketTypes.CONNACK, PacketTypes.DISCONNECT]}, 150: {"Message rate too high": [PacketTypes.DISCONNECT]}, - 151: {"Quota exceeded": [PacketTypes.CONNACK, PacketTypes.PUBACK, - PacketTypes.PUBREC, PacketTypes.SUBACK, PacketTypes.DISCONNECT], }, + 151: { + "Quota exceeded": [PacketTypes.CONNACK, PacketTypes.PUBACK, PacketTypes.PUBREC, PacketTypes.SUBACK, PacketTypes.DISCONNECT], + }, 152: {"Administrative action": [PacketTypes.DISCONNECT]}, - 153: {"Payload format invalid": - [PacketTypes.PUBACK, PacketTypes.PUBREC, PacketTypes.DISCONNECT]}, - 154: {"Retain not supported": - [PacketTypes.CONNACK, PacketTypes.DISCONNECT]}, - 155: {"QoS not supported": - [PacketTypes.CONNACK, PacketTypes.DISCONNECT]}, - 156: {"Use another server": - [PacketTypes.CONNACK, PacketTypes.DISCONNECT]}, - 157: {"Server moved": - [PacketTypes.CONNACK, PacketTypes.DISCONNECT]}, - 158: {"Shared subscription not supported": - [PacketTypes.SUBACK, PacketTypes.DISCONNECT]}, - 159: {"Connection rate exceeded": - [PacketTypes.CONNACK, PacketTypes.DISCONNECT]}, - 160: {"Maximum connect time": - [PacketTypes.DISCONNECT]}, - 161: {"Subscription identifiers not supported": - [PacketTypes.SUBACK, PacketTypes.DISCONNECT]}, - 162: {"Wildcard subscription not supported": - [PacketTypes.SUBACK, PacketTypes.DISCONNECT]}, + 153: {"Payload format invalid": [PacketTypes.PUBACK, PacketTypes.PUBREC, PacketTypes.DISCONNECT]}, + 154: {"Retain not supported": [PacketTypes.CONNACK, PacketTypes.DISCONNECT]}, + 155: {"QoS not supported": [PacketTypes.CONNACK, PacketTypes.DISCONNECT]}, + 156: {"Use another server": [PacketTypes.CONNACK, PacketTypes.DISCONNECT]}, + 157: {"Server moved": [PacketTypes.CONNACK, PacketTypes.DISCONNECT]}, + 158: {"Shared subscription not supported": [PacketTypes.SUBACK, PacketTypes.DISCONNECT]}, + 159: {"Connection rate exceeded": [PacketTypes.CONNACK, PacketTypes.DISCONNECT]}, + 160: {"Maximum connect time": [PacketTypes.DISCONNECT]}, + 161: {"Subscription identifiers not supported": [PacketTypes.SUBACK, PacketTypes.DISCONNECT]}, + 162: {"Wildcard subscription not supported": [PacketTypes.SUBACK, PacketTypes.DISCONNECT]}, } if identifier == -1: if packetType == PacketTypes.DISCONNECT and aName == "Success": @@ -165,8 +161,7 @@ def unpack(self, buffer): return 1 def getName(self): - """Returns the reason code name corresponding to the numeric value which is set. - """ + """Returns the reason code name corresponding to the numeric value which is set.""" return self.__getName__(self.packetType, self.value) def __eq__(self, other): diff --git a/src/paho/mqtt/subscribe.py b/src/paho/mqtt/subscribe.py index 955dfa13..7a1db27c 100644 --- a/src/paho/mqtt/subscribe.py +++ b/src/paho/mqtt/subscribe.py @@ -28,11 +28,12 @@ def _on_connect_v5(client, userdata, flags, rc, properties): if rc != 0: raise mqtt.MQTTException(paho.connack_string(rc)) - if isinstance(userdata['topics'], list): - for topic in userdata['topics']: - client.subscribe(topic, userdata['qos']) + if isinstance(userdata["topics"], list): + for topic in userdata["topics"]: + client.subscribe(topic, userdata["qos"]) else: - client.subscribe(userdata['topics'], userdata['qos']) + client.subscribe(userdata["topics"], userdata["qos"]) + def _on_connect(client, userdata, flags, rc): """Internal v5 callback""" @@ -41,35 +42,48 @@ def _on_connect(client, userdata, flags, rc): def _on_message_callback(client, userdata, message): """Internal callback""" - userdata['callback'](client, userdata['userdata'], message) + userdata["callback"](client, userdata["userdata"], message) def _on_message_simple(client, userdata, message): """Internal callback""" - if userdata['msg_count'] == 0: + if userdata["msg_count"] == 0: return # Don't process stale retained messages if 'retained' was false - if message.retain and not userdata['retained']: + if message.retain and not userdata["retained"]: return - userdata['msg_count'] = userdata['msg_count'] - 1 + userdata["msg_count"] = userdata["msg_count"] - 1 - if userdata['messages'] is None and userdata['msg_count'] == 0: - userdata['messages'] = message + if userdata["messages"] is None and userdata["msg_count"] == 0: + userdata["messages"] = message client.disconnect() return - userdata['messages'].append(message) - if userdata['msg_count'] == 0: + userdata["messages"].append(message) + if userdata["msg_count"] == 0: client.disconnect() -def callback(callback, topics, qos=0, userdata=None, hostname="localhost", - port=1883, client_id="", keepalive=60, will=None, auth=None, - tls=None, protocol=paho.MQTTv311, transport="tcp", - clean_session=True, proxy_args=None): +def callback( + callback, + topics, + qos=0, + userdata=None, + hostname="localhost", + port=1883, + client_id="", + keepalive=60, + will=None, + auth=None, + tls=None, + protocol=paho.MQTTv311, + transport="tcp", + clean_session=True, + proxy_args=None, +): """Subscribe to a list of topics and process them in a callback function. This function creates an MQTT client, connects to a broker and subscribes @@ -134,17 +148,11 @@ def callback(callback, topics, qos=0, userdata=None, hostname="localhost", """ if qos < 0 or qos > 2: - raise ValueError('qos must be in the range 0-2') + raise ValueError("qos must be in the range 0-2") - callback_userdata = { - 'callback':callback, - 'topics':topics, - 'qos':qos, - 'userdata':userdata} + callback_userdata = {"callback": callback, "topics": topics, "qos": qos, "userdata": userdata} - client = paho.Client(client_id=client_id, userdata=callback_userdata, - protocol=protocol, transport=transport, - clean_session=clean_session) + client = paho.Client(client_id=client_id, userdata=callback_userdata, protocol=protocol, transport=transport, clean_session=clean_session) client.on_message = _on_message_callback if protocol == mqtt.client.MQTTv5: client.on_connect = _on_connect_v5 @@ -155,20 +163,19 @@ def callback(callback, topics, qos=0, userdata=None, hostname="localhost", client.proxy_set(**proxy_args) if auth: - username = auth.get('username') + username = auth.get("username") if username: - password = auth.get('password') + password = auth.get("password") client.username_pw_set(username, password) else: - raise KeyError("The 'username' key was not found, this is " - "required for auth") + raise KeyError("The 'username' key was not found, this is " "required for auth") if will is not None: client.will_set(**will) if tls is not None: if isinstance(tls, dict): - insecure = tls.pop('insecure', False) + insecure = tls.pop("insecure", False) client.tls_set(**tls) if insecure: # Must be set *after* the `client.tls_set()` call since it sets @@ -182,10 +189,23 @@ def callback(callback, topics, qos=0, userdata=None, hostname="localhost", client.loop_forever() -def simple(topics, qos=0, msg_count=1, retained=True, hostname="localhost", - port=1883, client_id="", keepalive=60, will=None, auth=None, - tls=None, protocol=paho.MQTTv311, transport="tcp", - clean_session=True, proxy_args=None): +def simple( + topics, + qos=0, + msg_count=1, + retained=True, + hostname="localhost", + port=1883, + client_id="", + keepalive=60, + will=None, + auth=None, + tls=None, + protocol=paho.MQTTv311, + transport="tcp", + clean_session=True, + proxy_args=None, +): """Subscribe to a list of topics and return msg_count messages. This function creates an MQTT client, connects to a broker and subscribes @@ -258,7 +278,7 @@ def simple(topics, qos=0, msg_count=1, retained=True, hostname="localhost", """ if msg_count < 1: - raise ValueError('msg_count must be > 0') + raise ValueError("msg_count must be > 0") # Set ourselves up to return a single message if msg_count == 1, or a list # if > 1. @@ -271,10 +291,8 @@ def simple(topics, qos=0, msg_count=1, retained=True, hostname="localhost", if protocol == paho.MQTTv5: clean_session = None - userdata = {'retained':retained, 'msg_count':msg_count, 'messages':messages} + userdata = {"retained": retained, "msg_count": msg_count, "messages": messages} - callback(_on_message_simple, topics, qos, userdata, hostname, port, - client_id, keepalive, will, auth, tls, protocol, transport, - clean_session, proxy_args) + callback(_on_message_simple, topics, qos, userdata, hostname, port, client_id, keepalive, will, auth, tls, protocol, transport, clean_session, proxy_args) - return userdata['messages'] + return userdata["messages"] diff --git a/src/paho/mqtt/subscribeoptions.py b/src/paho/mqtt/subscribeoptions.py index f56973ce..bf9a0302 100644 --- a/src/paho/mqtt/subscribeoptions.py +++ b/src/paho/mqtt/subscribeoptions.py @@ -17,7 +17,6 @@ """ - class MQTTException(Exception): pass @@ -38,8 +37,7 @@ class SubscribeOptions: """ # retain handling options - RETAIN_SEND_ON_SUBSCRIBE, RETAIN_SEND_IF_NEW_SUB, RETAIN_DO_NOT_SEND = range( - 0, 3) + RETAIN_SEND_ON_SUBSCRIBE, RETAIN_SEND_IF_NEW_SUB, RETAIN_DO_NOT_SEND = range(0, 3) def __init__(self, qos=0, noLocal=False, retainAsPublished=False, retainHandling=RETAIN_SEND_ON_SUBSCRIBE): """ @@ -49,8 +47,7 @@ def __init__(self, qos=0, noLocal=False, retainAsPublished=False, retainHandling retainHandling: RETAIN_SEND_ON_SUBSCRIBE, RETAIN_SEND_IF_NEW_SUB or RETAIN_DO_NOT_SEND RETAIN_SEND_ON_SUBSCRIBE is the default and corresponds to MQTT v3.1.1 behavior. """ - object.__setattr__(self, "names", - ["QoS", "noLocal", "retainAsPublished", "retainHandling"]) + object.__setattr__(self, "names", ["QoS", "noLocal", "retainAsPublished", "retainHandling"]) self.QoS = qos # bits 0,1 self.noLocal = noLocal # bit 2 self.retainAsPublished = retainAsPublished # bit 3 @@ -62,8 +59,7 @@ def __init__(self, qos=0, noLocal=False, retainAsPublished=False, retainHandling def __setattr__(self, name, value): if name not in self.names: - raise MQTTException( - f"{name} Attribute name must be one of {self.names}") + raise MQTTException(f"{name} Attribute name must be one of {self.names}") object.__setattr__(self, name, value) def pack(self): @@ -73,16 +69,15 @@ def pack(self): raise AssertionError(f"QoS should be 0, 1 or 2, not {self.QoS}") noLocal = 1 if self.noLocal else 0 retainAsPublished = 1 if self.retainAsPublished else 0 - data = [(self.retainHandling << 4) | (retainAsPublished << 3) | - (noLocal << 2) | self.QoS] + data = [(self.retainHandling << 4) | (retainAsPublished << 3) | (noLocal << 2) | self.QoS] return bytes(data) def unpack(self, buffer): b0 = buffer[0] - self.retainHandling = ((b0 >> 4) & 0x03) + self.retainHandling = (b0 >> 4) & 0x03 self.retainAsPublished = True if ((b0 >> 3) & 0x01) == 1 else False self.noLocal = True if ((b0 >> 2) & 0x01) == 1 else False - self.QoS = (b0 & 0x03) + self.QoS = b0 & 0x03 if self.retainHandling not in (0, 1, 2): raise AssertionError(f"Retain handling should be 0, 1 or 2, not {self.retainHandling}") if self.QoS not in (0, 1, 2): @@ -93,9 +88,17 @@ def __repr__(self): return str(self) def __str__(self): - return "{QoS="+str(self.QoS)+", noLocal="+str(self.noLocal) +\ - ", retainAsPublished="+str(self.retainAsPublished) +\ - ", retainHandling="+str(self.retainHandling)+"}" + return ( + "{QoS=" + + str(self.QoS) + + ", noLocal=" + + str(self.noLocal) + + ", retainAsPublished=" + + str(self.retainAsPublished) + + ", retainHandling=" + + str(self.retainHandling) + + "}" + ) def json(self): data = { diff --git a/tests/debug_helpers.py b/tests/debug_helpers.py index 54b96368..94c92c46 100644 --- a/tests/debug_helpers.py +++ b/tests/debug_helpers.py @@ -8,7 +8,7 @@ def dump_packet(prefix: str, data: bytes) -> None: data = to_string(data) print(prefix, ": ", data, sep="") except struct.error: - data = binascii.b2a_hex(data).decode('utf8') + data = binascii.b2a_hex(data).decode("utf8") print(prefix, " (not decoded): 0x", data, sep="") @@ -23,7 +23,7 @@ def remaining_length(packet: bytes) -> Tuple[bytes, int]: rl += (byte & 127) * mult mult *= 128 if byte & 128 == 0: - packet = packet[i + 1:] + packet = packet[i + 1 :] break return (packet, rl) @@ -36,7 +36,7 @@ def to_hex_string(packet: bytes) -> str: s = "" while len(packet) > 0: packet0 = struct.unpack("!B", packet[0]) - s = s+hex(packet0[0]) + " " + s = s + hex(packet0[0]) + " " packet = packet[1:] return s @@ -46,7 +46,7 @@ def to_string(packet: bytes) -> str: if not packet: return "" - packet0 = struct.unpack("!B%ds" % (len(packet)-1), bytes(packet)) + packet0 = struct.unpack("!B%ds" % (len(packet) - 1), bytes(packet)) packet0 = packet0[0] cmd = packet0 & 0xF0 if cmd == 0x00: @@ -55,29 +55,29 @@ def to_string(packet: bytes) -> str: elif cmd == 0x10: # CONNECT (packet, rl) = remaining_length(packet) - pack_format = "!H" + str(len(packet) - 2) + 's' + pack_format = "!H" + str(len(packet) - 2) + "s" (slen, packet) = struct.unpack(pack_format, packet) - pack_format = "!" + str(slen) + 'sBBH' + str(len(packet) - slen - 4) + 's' + pack_format = "!" + str(slen) + "sBBH" + str(len(packet) - slen - 4) + "s" (protocol, proto_ver, flags, keepalive, packet) = struct.unpack(pack_format, packet) - kind = ("clean-session" if flags & 2 else "durable") + kind = "clean-session" if flags & 2 else "durable" s = f"CONNECT, proto={protocol}{proto_ver}, keepalive={keepalive}, {kind}" - pack_format = "!H" + str(len(packet) - 2) + 's' + pack_format = "!H" + str(len(packet) - 2) + "s" (slen, packet) = struct.unpack(pack_format, packet) - pack_format = "!" + str(slen) + 's' + str(len(packet) - slen) + 's' + pack_format = "!" + str(slen) + "s" + str(len(packet) - slen) + "s" (client_id, packet) = struct.unpack(pack_format, packet) s = s + ", id=" + str(client_id) if flags & 4: - pack_format = "!H" + str(len(packet) - 2) + 's' + pack_format = "!H" + str(len(packet) - 2) + "s" (slen, packet) = struct.unpack(pack_format, packet) - pack_format = "!" + str(slen) + 's' + str(len(packet) - slen) + 's' + pack_format = "!" + str(slen) + "s" + str(len(packet) - slen) + "s" (will_topic, packet) = struct.unpack(pack_format, packet) s = s + ", will-topic=" + str(will_topic) - pack_format = "!H" + str(len(packet) - 2) + 's' + pack_format = "!H" + str(len(packet) - 2) + "s" (slen, packet) = struct.unpack(pack_format, packet) - pack_format = "!" + str(slen) + 's' + str(len(packet) - slen) + 's' + pack_format = "!" + str(slen) + "s" + str(len(packet) - slen) + "s" (will_message, packet) = struct.unpack(pack_format, packet) s = s + ", will-message=" + will_message @@ -85,16 +85,16 @@ def to_string(packet: bytes) -> str: s = s + ", will-retain=" + str((flags & 32) >> 5) if flags & 128: - pack_format = "!H" + str(len(packet) - 2) + 's' + pack_format = "!H" + str(len(packet) - 2) + "s" (slen, packet) = struct.unpack(pack_format, packet) - pack_format = "!" + str(slen) + 's' + str(len(packet) - slen) + 's' + pack_format = "!" + str(slen) + "s" + str(len(packet) - slen) + "s" (username, packet) = struct.unpack(pack_format, packet) s = s + ", username=" + str(username) if flags & 64: - pack_format = "!H" + str(len(packet) - 2) + 's' + pack_format = "!H" + str(len(packet) - 2) + "s" (slen, packet) = struct.unpack(pack_format, packet) - pack_format = "!" + str(slen) + 's' + str(len(packet) - slen) + 's' + pack_format = "!" + str(slen) + "s" + str(len(packet) - slen) + "s" (password, packet) = struct.unpack(pack_format, packet) s = s + ", password=" + str(password) @@ -105,11 +105,11 @@ def to_string(packet: bytes) -> str: elif cmd == 0x20: # CONNACK if len(packet) == 4: - (cmd, rl, resv, rc) = struct.unpack('!BBBB', packet) - return "CONNACK, rl="+str(rl)+", res="+str(resv)+", rc="+str(rc) + (cmd, rl, resv, rc) = struct.unpack("!BBBB", packet) + return "CONNACK, rl=" + str(rl) + ", res=" + str(resv) + ", rc=" + str(rc) elif len(packet) == 5: - (cmd, rl, flags, reason_code, proplen) = struct.unpack('!BBBBB', packet) - return "CONNACK, rl="+str(rl)+", flags="+str(flags)+", rc="+str(reason_code)+", proplen="+str(proplen) + (cmd, rl, flags, reason_code, proplen) = struct.unpack("!BBBBB", packet) + return "CONNACK, rl=" + str(rl) + ", flags=" + str(flags) + ", rc=" + str(reason_code) + ", proplen=" + str(proplen) else: return "CONNACK, (not decoded)" @@ -117,15 +117,15 @@ def to_string(packet: bytes) -> str: # PUBLISH dup = (packet0 & 0x08) >> 3 qos = (packet0 & 0x06) >> 1 - retain = (packet0 & 0x01) + retain = packet0 & 0x01 (packet, rl) = remaining_length(packet) - pack_format = "!H" + str(len(packet) - 2) + 's' + pack_format = "!H" + str(len(packet) - 2) + "s" (tlen, packet) = struct.unpack(pack_format, packet) - pack_format = "!" + str(tlen) + 's' + str(len(packet) - tlen) + 's' + pack_format = "!" + str(tlen) + "s" + str(len(packet) - tlen) + "s" (topic, packet) = struct.unpack(pack_format, packet) s = "PUBLISH, rl=" + str(rl) + ", topic=" + str(topic) + ", qos=" + str(qos) + ", retain=" + str(retain) + ", dup=" + str(dup) if qos > 0: - pack_format = "!H" + str(len(packet) - 2) + 's' + pack_format = "!H" + str(len(packet) - 2) + "s" (mid, packet) = struct.unpack(pack_format, packet) s = s + ", mid=" + str(mid) @@ -134,46 +134,46 @@ def to_string(packet: bytes) -> str: elif cmd == 0x40: # PUBACK if len(packet) == 5: - (cmd, rl, mid, reason_code) = struct.unpack('!BBHB', packet) - return "PUBACK, rl="+str(rl)+", mid="+str(mid)+", reason_code="+str(reason_code) + (cmd, rl, mid, reason_code) = struct.unpack("!BBHB", packet) + return "PUBACK, rl=" + str(rl) + ", mid=" + str(mid) + ", reason_code=" + str(reason_code) else: - (cmd, rl, mid) = struct.unpack('!BBH', packet) - return "PUBACK, rl="+str(rl)+", mid="+str(mid) + (cmd, rl, mid) = struct.unpack("!BBH", packet) + return "PUBACK, rl=" + str(rl) + ", mid=" + str(mid) elif cmd == 0x50: # PUBREC if len(packet) == 5: - (cmd, rl, mid, reason_code) = struct.unpack('!BBHB', packet) - return "PUBREC, rl="+str(rl)+", mid="+str(mid)+", reason_code="+str(reason_code) + (cmd, rl, mid, reason_code) = struct.unpack("!BBHB", packet) + return "PUBREC, rl=" + str(rl) + ", mid=" + str(mid) + ", reason_code=" + str(reason_code) else: - (cmd, rl, mid) = struct.unpack('!BBH', packet) - return "PUBREC, rl="+str(rl)+", mid="+str(mid) + (cmd, rl, mid) = struct.unpack("!BBH", packet) + return "PUBREC, rl=" + str(rl) + ", mid=" + str(mid) elif cmd == 0x60: # PUBREL dup = (packet0 & 0x08) >> 3 - (cmd, rl, mid) = struct.unpack('!BBH', packet) + (cmd, rl, mid) = struct.unpack("!BBH", packet) return "PUBREL, rl=" + str(rl) + ", mid=" + str(mid) + ", dup=" + str(dup) elif cmd == 0x70: # PUBCOMP - (cmd, rl, mid) = struct.unpack('!BBH', packet) + (cmd, rl, mid) = struct.unpack("!BBH", packet) return "PUBCOMP, rl=" + str(rl) + ", mid=" + str(mid) elif cmd == 0x80: # SUBSCRIBE (packet, rl) = remaining_length(packet) - pack_format = "!H" + str(len(packet) - 2) + 's' + pack_format = "!H" + str(len(packet) - 2) + "s" (mid, packet) = struct.unpack(pack_format, packet) s = "SUBSCRIBE, rl=" + str(rl) + ", mid=" + str(mid) topic_index = 0 while len(packet) > 0: - pack_format = "!H" + str(len(packet) - 2) + 's' + pack_format = "!H" + str(len(packet) - 2) + "s" (tlen, packet) = struct.unpack(pack_format, packet) - pack_format = "!" + str(tlen) + 'sB' + str(len(packet) - tlen - 1) + 's' + pack_format = "!" + str(tlen) + "sB" + str(len(packet) - tlen - 1) + "s" (topic, qos, packet) = struct.unpack(pack_format, packet) s = s + ", topic" + str(topic_index) + "=" + str(topic) + "," + str(qos) return s elif cmd == 0x90: # SUBACK (packet, rl) = remaining_length(packet) - pack_format = "!H" + str(len(packet) - 2) + 's' + pack_format = "!H" + str(len(packet) - 2) + "s" (mid, packet) = struct.unpack(pack_format, packet) pack_format = "!" + "B" * len(packet) granted_qos = struct.unpack(pack_format, packet) @@ -185,39 +185,39 @@ def to_string(packet: bytes) -> str: elif cmd == 0xA0: # UNSUBSCRIBE (packet, rl) = remaining_length(packet) - pack_format = "!H" + str(len(packet) - 2) + 's' + pack_format = "!H" + str(len(packet) - 2) + "s" (mid, packet) = struct.unpack(pack_format, packet) s = "UNSUBSCRIBE, rl=" + str(rl) + ", mid=" + str(mid) topic_index = 0 while len(packet) > 0: - pack_format = "!H" + str(len(packet) - 2) + 's' + pack_format = "!H" + str(len(packet) - 2) + "s" (tlen, packet) = struct.unpack(pack_format, packet) - pack_format = "!" + str(tlen) + 's' + str(len(packet) - tlen) + 's' + pack_format = "!" + str(tlen) + "s" + str(len(packet) - tlen) + "s" (topic, packet) = struct.unpack(pack_format, packet) s = s + ", topic" + str(topic_index) + "=" + str(topic) return s elif cmd == 0xB0: # UNSUBACK - (cmd, rl, mid) = struct.unpack('!BBH', packet) + (cmd, rl, mid) = struct.unpack("!BBH", packet) return "UNSUBACK, rl=" + str(rl) + ", mid=" + str(mid) elif cmd == 0xC0: # PINGREQ - (cmd, rl) = struct.unpack('!BB', packet) + (cmd, rl) = struct.unpack("!BB", packet) return "PINGREQ, rl=" + str(rl) elif cmd == 0xD0: # PINGRESP - (cmd, rl) = struct.unpack('!BB', packet) + (cmd, rl) = struct.unpack("!BB", packet) return "PINGRESP, rl=" + str(rl) elif cmd == 0xE0: # DISCONNECT if len(packet) == 3: - (cmd, rl, reason_code) = struct.unpack('!BBB', packet) - return "DISCONNECT, rl="+str(rl)+", reason_code="+str(reason_code) + (cmd, rl, reason_code) = struct.unpack("!BBB", packet) + return "DISCONNECT, rl=" + str(rl) + ", reason_code=" + str(reason_code) else: - (cmd, rl) = struct.unpack('!BB', packet) - return "DISCONNECT, rl="+str(rl) + (cmd, rl) = struct.unpack("!BB", packet) + return "DISCONNECT, rl=" + str(rl) elif cmd == 0xF0: # AUTH - (cmd, rl) = struct.unpack('!BB', packet) - return "AUTH, rl="+str(rl) + (cmd, rl) = struct.unpack("!BB", packet) + return "AUTH, rl=" + str(rl) raise ValueError(f"Unknown packet type {cmd}") diff --git a/tests/lib/clients/01-asyncio.py b/tests/lib/clients/01-asyncio.py index eeab4433..cfd99a6b 100644 --- a/tests/lib/clients/01-asyncio.py +++ b/tests/lib/clients/01-asyncio.py @@ -5,7 +5,7 @@ from tests.paho_test import get_test_server_port -client_id = 'asyncio-test' +client_id = "asyncio-test" class AsyncioHelper: @@ -80,10 +80,11 @@ def on_disconnect(client, userdata, rc): _aioh = AsyncioHelper(loop, client) - client.connect('localhost', get_test_server_port(), 60) + client.connect("localhost", get_test_server_port(), 60) client.socket().setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 2048) await disconnected -if __name__ == '__main__': + +if __name__ == "__main__": asyncio.run(main()) diff --git a/tests/lib/clients/01-unpwd-unicode-set.py b/tests/lib/clients/01-unpwd-unicode-set.py index b0d12f9d..1f49e1fe 100644 --- a/tests/lib/clients/01-unpwd-unicode-set.py +++ b/tests/lib/clients/01-unpwd-unicode-set.py @@ -1,4 +1,3 @@ - import paho.mqtt.client as mqtt from tests.paho_test import get_test_server_port, loop_until_keyboard_interrupt diff --git a/tests/lib/clients/03-publish-fill-inflight.py b/tests/lib/clients/03-publish-fill-inflight.py index f954dd13..ea427a9f 100644 --- a/tests/lib/clients/03-publish-fill-inflight.py +++ b/tests/lib/clients/03-publish-fill-inflight.py @@ -22,10 +22,12 @@ def on_connect(mqttc, obj, flags, rc): for i in range(12): mqttc.publish("topic", expected_payload(i), qos=1) + def on_disconnect(mqttc, rc, properties): logging.info("disconnected") mqttc.reconnect() + logging.basicConfig(level=logging.DEBUG) logging.info(str(mqtt)) mqttc = mqtt.Client("publish-qos1-test") diff --git a/tests/lib/conftest.py b/tests/lib/conftest.py index bb2f4278..f6415071 100644 --- a/tests/lib/conftest.py +++ b/tests/lib/conftest.py @@ -53,12 +53,15 @@ def starter(name: str, expected_returncode: int = 0) -> None: PAHO_SSL_PATH=str(ssl_path), PYTHONPATH=f"{tests_path}{os.pathsep}{os.environ.get('PYTHONPATH', '')}", ) - assert 'PAHO_SERVER_PORT' in env, "PAHO_SERVER_PORT must be set in the environment when starting a client" + assert "PAHO_SERVER_PORT" in env, "PAHO_SERVER_PORT must be set in the environment when starting a client" # TODO: it would be nice to run this under `coverage` too! - proc = subprocess.Popen([ # noqa: S603 - sys.executable, - str(client_path), - ], env=env) + proc = subprocess.Popen( + [ # noqa: S603 + sys.executable, + str(client_path), + ], + env=env, + ) def fin(): stop_process(proc) diff --git a/tests/lib/test_01_reconnect_on_failure.py b/tests/lib/test_01_reconnect_on_failure.py index 8deb6539..832f5ee2 100644 --- a/tests/lib/test_01_reconnect_on_failure.py +++ b/tests/lib/test_01_reconnect_on_failure.py @@ -7,8 +7,7 @@ connack_packet_ok = paho_test.gen_connack(rc=0) connack_packet_failure = paho_test.gen_connack(rc=1) # CONNACK_REFUSED_PROTOCOL_VERSION -publish_packet = paho_test.gen_publish( - "reconnect/test", qos=0, payload="message") +publish_packet = paho_test.gen_publish("reconnect/test", qos=0, payload="message") @pytest.mark.parametrize("ok_code", [False, True]) diff --git a/tests/lib/test_01_unpwd_empty_password_set.py b/tests/lib/test_01_unpwd_empty_password_set.py index 225d72bc..d19b1709 100644 --- a/tests/lib/test_01_unpwd_empty_password_set.py +++ b/tests/lib/test_01_unpwd_empty_password_set.py @@ -6,8 +6,7 @@ import tests.paho_test as paho_test -connect_packet = paho_test.gen_connect( - "01-unpwd-set", keepalive=60, username="uname", password="") +connect_packet = paho_test.gen_connect("01-unpwd-set", keepalive=60, username="uname", password="") def test_01_unpwd_empty_password_set(server_socket, start_client): diff --git a/tests/lib/test_01_unpwd_empty_set.py b/tests/lib/test_01_unpwd_empty_set.py index 8c51c22a..d6400f3e 100644 --- a/tests/lib/test_01_unpwd_empty_set.py +++ b/tests/lib/test_01_unpwd_empty_set.py @@ -6,8 +6,7 @@ import tests.paho_test as paho_test -connect_packet = paho_test.gen_connect( - "01-unpwd-set", keepalive=60, username="", password='') +connect_packet = paho_test.gen_connect("01-unpwd-set", keepalive=60, username="", password="") def test_01_unpwd_empty_set(server_socket, start_client): diff --git a/tests/lib/test_01_unpwd_set.py b/tests/lib/test_01_unpwd_set.py index 38834ab9..f86a4060 100644 --- a/tests/lib/test_01_unpwd_set.py +++ b/tests/lib/test_01_unpwd_set.py @@ -6,8 +6,7 @@ import tests.paho_test as paho_test -connect_packet = paho_test.gen_connect( - "01-unpwd-set", keepalive=60, username="uname", password=";'[08gn=#") +connect_packet = paho_test.gen_connect("01-unpwd-set", keepalive=60, username="uname", password=";'[08gn=#") def test_01_unpwd_set(server_socket, start_client): diff --git a/tests/lib/test_01_will_set.py b/tests/lib/test_01_will_set.py index 55b55a25..12431122 100644 --- a/tests/lib/test_01_will_set.py +++ b/tests/lib/test_01_will_set.py @@ -9,8 +9,8 @@ import tests.paho_test as paho_test connect_packet = paho_test.gen_connect( - "01-will-set", keepalive=60, will_topic="topic/on/unexpected/disconnect", - will_qos=1, will_retain=True, will_payload="will message") + "01-will-set", keepalive=60, will_topic="topic/on/unexpected/disconnect", will_qos=1, will_retain=True, will_payload="will message" +) def test_01_will_set(server_socket, start_client): diff --git a/tests/lib/test_01_will_unpwd_set.py b/tests/lib/test_01_will_unpwd_set.py index 95c0517f..a38b1b67 100644 --- a/tests/lib/test_01_will_unpwd_set.py +++ b/tests/lib/test_01_will_unpwd_set.py @@ -10,8 +10,12 @@ connect_packet = paho_test.gen_connect( "01-will-unpwd-set", - keepalive=60, username="oibvvwqw", password="#'^2hg9a&nm38*us", - will_topic="will-topic", will_qos=2, will_payload="will message", + keepalive=60, + username="oibvvwqw", + password="#'^2hg9a&nm38*us", + will_topic="will-topic", + will_qos=2, + will_payload="will message", ) diff --git a/tests/lib/test_03_publish_b2c_qos1.py b/tests/lib/test_03_publish_b2c_qos1.py index 4666900d..8082c55b 100644 --- a/tests/lib/test_03_publish_b2c_qos1.py +++ b/tests/lib/test_03_publish_b2c_qos1.py @@ -17,8 +17,7 @@ disconnect_packet = paho_test.gen_disconnect() mid = 123 -publish_packet = paho_test.gen_publish( - "pub/qos1/receive", qos=1, mid=mid, payload="message") +publish_packet = paho_test.gen_publish("pub/qos1/receive", qos=1, mid=mid, payload="message") puback_packet = paho_test.gen_puback(mid) diff --git a/tests/lib/test_03_publish_c2b_qos1_disconnect.py b/tests/lib/test_03_publish_c2b_qos1_disconnect.py index 10daca6a..500765dc 100644 --- a/tests/lib/test_03_publish_c2b_qos1_disconnect.py +++ b/tests/lib/test_03_publish_c2b_qos1_disconnect.py @@ -4,17 +4,17 @@ import tests.paho_test as paho_test connect_packet = paho_test.gen_connect( - "publish-qos1-test", keepalive=60, clean_session=False, + "publish-qos1-test", + keepalive=60, + clean_session=False, ) connack_packet = paho_test.gen_connack(rc=0) disconnect_packet = paho_test.gen_disconnect() mid = 1 -publish_packet = paho_test.gen_publish( - "pub/qos1/test", qos=1, mid=mid, payload="message") -publish_packet_dup = paho_test.gen_publish( - "pub/qos1/test", qos=1, mid=mid, payload="message", dup=True) +publish_packet = paho_test.gen_publish("pub/qos1/test", qos=1, mid=mid, payload="message") +publish_packet_dup = paho_test.gen_publish("pub/qos1/test", qos=1, mid=mid, payload="message", dup=True) puback_packet = paho_test.gen_puback(mid) diff --git a/tests/lib/test_03_publish_c2b_qos2_disconnect.py b/tests/lib/test_03_publish_c2b_qos2_disconnect.py index 15b1d496..4c7ab85d 100644 --- a/tests/lib/test_03_publish_c2b_qos2_disconnect.py +++ b/tests/lib/test_03_publish_c2b_qos2_disconnect.py @@ -4,17 +4,17 @@ import tests.paho_test as paho_test connect_packet = paho_test.gen_connect( - "publish-qos2-test", keepalive=60, clean_session=False, + "publish-qos2-test", + keepalive=60, + clean_session=False, ) connack_packet = paho_test.gen_connack(rc=0) disconnect_packet = paho_test.gen_disconnect() mid = 1 -publish_packet = paho_test.gen_publish( - "pub/qos2/test", qos=2, mid=mid, payload="message") -publish_dup_packet = paho_test.gen_publish( - "pub/qos2/test", qos=2, mid=mid, payload="message", dup=True) +publish_packet = paho_test.gen_publish("pub/qos2/test", qos=2, mid=mid, payload="message") +publish_dup_packet = paho_test.gen_publish("pub/qos2/test", qos=2, mid=mid, payload="message", dup=True) pubrec_packet = paho_test.gen_pubrec(mid) pubrel_packet = paho_test.gen_pubrel(mid) pubcomp_packet = paho_test.gen_pubcomp(mid) diff --git a/tests/lib/test_03_publish_fill_inflight.py b/tests/lib/test_03_publish_fill_inflight.py index 697c896a..1a65e856 100644 --- a/tests/lib/test_03_publish_fill_inflight.py +++ b/tests/lib/test_03_publish_fill_inflight.py @@ -22,6 +22,7 @@ def expected_payload(i: int) -> bytes: return f"message{i}" + connect_packet = paho_test.gen_connect("publish-qos1-test", keepalive=60) connack_packet = paho_test.gen_connack(rc=0) @@ -29,7 +30,10 @@ def expected_payload(i: int) -> bytes: first_connection_publishs = [ paho_test.gen_publish( - "topic", qos=1, mid=i+1, payload=expected_payload(i), + "topic", + qos=1, + mid=i + 1, + payload=expected_payload(i), ) for i in range(10) ] @@ -39,14 +43,15 @@ def expected_payload(i: int) -> bytes: # Currently on reconnection client will do two wrong thing: # * it sent more than max_inflight packet # * it re-send message both with mid = old_mid + 12 AND with mid = old_mid & dup=1 - "topic", qos=1, mid=i+13, payload=expected_payload(i), + "topic", + qos=1, + mid=i + 13, + payload=expected_payload(i), ) for i in range(12) ] -second_connection_pubacks = [ - paho_test.gen_puback(i+13) - for i in range(12) -] +second_connection_pubacks = [paho_test.gen_puback(i + 13) for i in range(12)] + @pytest.mark.xfail def test_03_publish_fill_inflight(server_socket, start_client): @@ -87,4 +92,3 @@ def test_03_publish_fill_inflight(server_socket, start_client): paho_test.expect_packet(conn, "publish", second_connection_publishs[11]) paho_test.expect_no_packet(conn, 0.5) - diff --git a/tests/lib/test_03_publish_helper_qos0.py b/tests/lib/test_03_publish_helper_qos0.py index b1c57d90..dd4ba400 100644 --- a/tests/lib/test_03_publish_helper_qos0.py +++ b/tests/lib/test_03_publish_helper_qos0.py @@ -14,13 +14,12 @@ import tests.paho_test as paho_test connect_packet = paho_test.gen_connect( - "publish-helper-qos0-test", keepalive=60, + "publish-helper-qos0-test", + keepalive=60, ) connack_packet = paho_test.gen_connack(rc=0) -publish_packet = paho_test.gen_publish( - "pub/qos0/test", qos=0, payload="message" -) +publish_packet = paho_test.gen_publish("pub/qos0/test", qos=0, payload="message") disconnect_packet = paho_test.gen_disconnect() diff --git a/tests/lib/test_03_publish_helper_qos0_v5.py b/tests/lib/test_03_publish_helper_qos0_v5.py index ab950778..744a0e58 100644 --- a/tests/lib/test_03_publish_helper_qos0_v5.py +++ b/tests/lib/test_03_publish_helper_qos0_v5.py @@ -13,14 +13,10 @@ import tests.paho_test as paho_test -connect_packet = paho_test.gen_connect( - "publish-helper-qos0-test", keepalive=60, proto_ver=5, properties=None -) +connect_packet = paho_test.gen_connect("publish-helper-qos0-test", keepalive=60, proto_ver=5, properties=None) connack_packet = paho_test.gen_connack(rc=0, proto_ver=5) -publish_packet = paho_test.gen_publish( - "pub/qos0/test", qos=0, payload="message", proto_ver=5 -) +publish_packet = paho_test.gen_publish("pub/qos0/test", qos=0, payload="message", proto_ver=5) disconnect_packet = paho_test.gen_disconnect() diff --git a/tests/lib/test_03_publish_helper_qos1_disconnect.py b/tests/lib/test_03_publish_helper_qos1_disconnect.py index f73462c3..c18855ea 100644 --- a/tests/lib/test_03_publish_helper_qos1_disconnect.py +++ b/tests/lib/test_03_publish_helper_qos1_disconnect.py @@ -6,16 +6,18 @@ import tests.paho_test as paho_test connect_packet = paho_test.gen_connect( - "publish-helper-qos1-disconnect-test", keepalive=60, + "publish-helper-qos1-disconnect-test", + keepalive=60, ) connack_packet = paho_test.gen_connack(rc=0) mid = 1 -publish_packet = paho_test.gen_publish( - "pub/qos1/test", qos=1, mid=mid, payload="message" -) +publish_packet = paho_test.gen_publish("pub/qos1/test", qos=1, mid=mid, payload="message") publish_packet_dup = paho_test.gen_publish( - "pub/qos1/test", qos=1, mid=mid, payload="message", + "pub/qos1/test", + qos=1, + mid=mid, + payload="message", dup=True, ) puback_packet = paho_test.gen_puback(mid) diff --git a/tests/lib/test_04_retain_qos0.py b/tests/lib/test_04_retain_qos0.py index dee6b099..80b7afff 100644 --- a/tests/lib/test_04_retain_qos0.py +++ b/tests/lib/test_04_retain_qos0.py @@ -7,8 +7,7 @@ connect_packet = paho_test.gen_connect("retain-qos0-test", keepalive=60) connack_packet = paho_test.gen_connack(rc=0) -publish_packet = paho_test.gen_publish( - "retain/qos0/test", qos=0, payload="retained message", retain=True) +publish_packet = paho_test.gen_publish("retain/qos0/test", qos=0, payload="retained message", retain=True) def test_04_retain_qos0(server_socket, start_client): diff --git a/tests/mqtt5_props.py b/tests/mqtt5_props.py index f9be0d66..a69232d6 100644 --- a/tests/mqtt5_props.py +++ b/tests/mqtt5_props.py @@ -28,33 +28,40 @@ PROP_SUBSCRIPTION_ID_AVAILABLE = 41 PROP_SHARED_SUB_AVAILABLE = 42 + def gen_byte_prop(identifier, byte): - prop = struct.pack('BB', identifier, byte) + prop = struct.pack("BB", identifier, byte) return prop + def gen_uint16_prop(identifier, word): - prop = struct.pack('!BH', identifier, word) + prop = struct.pack("!BH", identifier, word) return prop + def gen_uint32_prop(identifier, word): - prop = struct.pack('!BI', identifier, word) + prop = struct.pack("!BI", identifier, word) return prop + def gen_string_prop(identifier, s): s = s.encode("utf-8") - prop = struct.pack(f'!BH{len(s)}s', identifier, len(s), s) + prop = struct.pack(f"!BH{len(s)}s", identifier, len(s), s) return prop + def gen_string_pair_prop(identifier, s1, s2): s1 = s1.encode("utf-8") s2 = s2.encode("utf-8") - prop = struct.pack(f'!BH{len(s1)}sH{len(s2)}s', identifier, len(s1), s1, len(s2), s2) + prop = struct.pack(f"!BH{len(s1)}sH{len(s2)}s", identifier, len(s1), s1, len(s2), s2) return prop + def gen_varint_prop(identifier, val): v = pack_varint(val) return struct.pack(f"!B{len(v)}s", identifier, v) + def pack_varint(varint): s = b"" while True: @@ -68,9 +75,9 @@ def pack_varint(varint): if varint == 0: return s + def prop_finalise(props): if props is None: return pack_varint(0) else: return pack_varint(len(props)) + props - diff --git a/tests/paho_test.py b/tests/paho_test.py index 4c77fb30..15d665bb 100644 --- a/tests/paho_test.py +++ b/tests/paho_test.py @@ -20,7 +20,7 @@ def bind_to_any_free_port(sock) -> int: Bind a socket to an available port on localhost, and return the port number. """ - sock.bind(('localhost', 0)) + sock.bind(("localhost", 0)) return sock.getsockname()[1] @@ -58,7 +58,7 @@ def expect_packet(sock, name, expected): packet_recvd = b"" try: while len(packet_recvd) < rlen: - data = sock.recv(rlen-len(packet_recvd)) + data = sock.recv(rlen - len(packet_recvd)) if len(data) == 0: break packet_recvd += data @@ -70,8 +70,7 @@ def expect_packet(sock, name, expected): def expect_no_packet(sock, delay=1): - """ expect that nothing is received within given delay - """ + """expect that nothing is received within given delay""" sock.settimeout(delay) try: previous_timeout = sock.gettimeout() @@ -97,17 +96,32 @@ def packet_matches(name, recvd, expected): return True -def gen_connect(client_id, clean_session=True, keepalive=60, username=None, password=None, will_topic=None, will_qos=0, will_retain=False, will_payload=b"", proto_ver=4, connect_reserved=False, properties=b"", will_properties=b"", session_expiry=-1): - if (proto_ver&0x7F) == 3 or proto_ver == 0: +def gen_connect( + client_id, + clean_session=True, + keepalive=60, + username=None, + password=None, + will_topic=None, + will_qos=0, + will_retain=False, + will_payload=b"", + proto_ver=4, + connect_reserved=False, + properties=b"", + will_properties=b"", + session_expiry=-1, +): + if (proto_ver & 0x7F) == 3 or proto_ver == 0: remaining_length = 12 - elif (proto_ver&0x7F) == 4 or proto_ver == 5: + elif (proto_ver & 0x7F) == 4 or proto_ver == 5: remaining_length = 10 else: raise ValueError if client_id is not None: client_id = client_id.encode("utf-8") - remaining_length = remaining_length + 2+len(client_id) + remaining_length = remaining_length + 2 + len(client_id) else: remaining_length = remaining_length + 2 @@ -130,7 +144,7 @@ def gen_connect(client_id, clean_session=True, keepalive=60, username=None, pass remaining_length += len(properties) if will_topic is not None: - will_topic = will_topic.encode('utf-8') + will_topic = will_topic.encode("utf-8") remaining_length = remaining_length + 2 + len(will_topic) + 2 + len(will_payload) connect_flags = connect_flags | 0x04 | ((will_qos & 0x03) << 3) if will_retain: @@ -140,19 +154,19 @@ def gen_connect(client_id, clean_session=True, keepalive=60, username=None, pass remaining_length += len(will_properties) if username is not None: - username = username.encode('utf-8') + username = username.encode("utf-8") remaining_length = remaining_length + 2 + len(username) connect_flags = connect_flags | 0x80 if password is not None: - password = password.encode('utf-8') + password = password.encode("utf-8") connect_flags = connect_flags | 0x40 remaining_length = remaining_length + 2 + len(password) rl = pack_remaining_length(remaining_length) packet = struct.pack("!B" + str(len(rl)) + "s", 0x10, rl) - if (proto_ver&0x7F) == 3 or proto_ver == 0: + if (proto_ver & 0x7F) == 3 or proto_ver == 0: packet = packet + struct.pack("!H6sBBH", len(b"MQIsdp"), b"MQIsdp", proto_ver, connect_flags, keepalive) - elif (proto_ver&0x7F) == 4 or proto_ver == 5: + elif (proto_ver & 0x7F) == 4 or proto_ver == 5: packet = packet + struct.pack("!H4sBBH", len(b"MQTT"), b"MQTT", proto_ver, connect_flags, keepalive) if proto_ver == 5: @@ -167,7 +181,7 @@ def gen_connect(client_id, clean_session=True, keepalive=60, username=None, pass packet += will_properties packet = packet + struct.pack("!H" + str(len(will_topic)) + "s", len(will_topic), will_topic) if len(will_payload) > 0: - packet = packet + struct.pack("!H" + str(len(will_payload)) + "s", len(will_payload), will_payload.encode('utf8')) + packet = packet + struct.pack("!H" + str(len(will_payload)) + "s", len(will_payload), will_payload.encode("utf8")) else: packet = packet + struct.pack("!H", 0) @@ -177,27 +191,32 @@ def gen_connect(client_id, clean_session=True, keepalive=60, username=None, pass packet = packet + struct.pack("!H" + str(len(password)) + "s", len(password), password) return packet + def gen_connack(flags=0, rc=0, proto_ver=4, properties=b"", property_helper=True): if proto_ver == 5: if property_helper: if properties is not None: - properties = mqtt5_props.gen_uint16_prop(mqtt5_props.PROP_TOPIC_ALIAS_MAXIMUM, 10) \ - + properties + mqtt5_props.gen_uint16_prop(mqtt5_props.PROP_RECEIVE_MAXIMUM, 20) + properties = ( + mqtt5_props.gen_uint16_prop(mqtt5_props.PROP_TOPIC_ALIAS_MAXIMUM, 10) + + properties + + mqtt5_props.gen_uint16_prop(mqtt5_props.PROP_RECEIVE_MAXIMUM, 20) + ) else: properties = b"" properties = mqtt5_props.prop_finalise(properties) - packet = struct.pack('!BBBB', 32, 2+len(properties), flags, rc) + properties + packet = struct.pack("!BBBB", 32, 2 + len(properties), flags, rc) + properties else: - packet = struct.pack('!BBBB', 32, 2, flags, rc) + packet = struct.pack("!BBBB", 32, 2, flags, rc) return packet + def gen_publish(topic, qos, payload=None, retain=False, dup=False, mid=0, proto_ver=4, properties=b""): if isinstance(topic, str): topic = topic.encode("utf-8") - rl = 2+len(topic) - pack_format = "H"+str(len(topic))+"s" + rl = 2 + len(topic) + pack_format = "H" + str(len(topic)) + "s" if qos > 0: rl = rl + 2 pack_format = pack_format + "H" @@ -206,7 +225,7 @@ def gen_publish(topic, qos, payload=None, retain=False, dup=False, mid=0, proto_ properties = mqtt5_props.prop_finalise(properties) rl += len(properties) # This will break if len(properties) > 127 - pack_format = pack_format + "%ds"%(len(properties)) + pack_format = pack_format + "%ds" % (len(properties)) if payload is not None: payload = payload.encode("utf-8") @@ -225,14 +244,15 @@ def gen_publish(topic, qos, payload=None, retain=False, dup=False, mid=0, proto_ if proto_ver == 5: if qos > 0: - return struct.pack("!B" + str(len(rlpacked))+"s" + pack_format, cmd, rlpacked, len(topic), topic, mid, properties, payload) + return struct.pack("!B" + str(len(rlpacked)) + "s" + pack_format, cmd, rlpacked, len(topic), topic, mid, properties, payload) else: - return struct.pack("!B" + str(len(rlpacked))+"s" + pack_format, cmd, rlpacked, len(topic), topic, properties, payload) + return struct.pack("!B" + str(len(rlpacked)) + "s" + pack_format, cmd, rlpacked, len(topic), topic, properties, payload) else: if qos > 0: - return struct.pack("!B" + str(len(rlpacked))+"s" + pack_format, cmd, rlpacked, len(topic), topic, mid, payload) + return struct.pack("!B" + str(len(rlpacked)) + "s" + pack_format, cmd, rlpacked, len(topic), topic, mid, payload) else: - return struct.pack("!B" + str(len(rlpacked))+"s" + pack_format, cmd, rlpacked, len(topic), topic, payload) + return struct.pack("!B" + str(len(rlpacked)) + "s" + pack_format, cmd, rlpacked, len(topic), topic, payload) + def _gen_command_with_mid(cmd, mid, proto_ver=4, reason_code=-1, properties=None): if proto_ver == 5 and (reason_code != -1 or properties is not None): @@ -240,29 +260,33 @@ def _gen_command_with_mid(cmd, mid, proto_ver=4, reason_code=-1, properties=None reason_code = 0 if properties is None: - return struct.pack('!BBHB', cmd, 3, mid, reason_code) + return struct.pack("!BBHB", cmd, 3, mid, reason_code) elif properties == "": - return struct.pack('!BBHBB', cmd, 4, mid, reason_code, 0) + return struct.pack("!BBHBB", cmd, 4, mid, reason_code, 0) else: properties = mqtt5_props.prop_finalise(properties) - pack_format = "!BBHB"+str(len(properties))+"s" - return struct.pack(pack_format, cmd, 2+1+len(properties), mid, reason_code, properties) + pack_format = "!BBHB" + str(len(properties)) + "s" + return struct.pack(pack_format, cmd, 2 + 1 + len(properties), mid, reason_code, properties) else: - return struct.pack('!BBH', cmd, 2, mid) + return struct.pack("!BBH", cmd, 2, mid) + def gen_puback(mid, proto_ver=4, reason_code=-1, properties=None): return _gen_command_with_mid(64, mid, proto_ver, reason_code, properties) + def gen_pubrec(mid, proto_ver=4, reason_code=-1, properties=None): return _gen_command_with_mid(80, mid, proto_ver, reason_code, properties) + def gen_pubrel(mid, dup=False, proto_ver=4, reason_code=-1, properties=None): if dup: - cmd = 96+8+2 + cmd = 96 + 8 + 2 else: - cmd = 96+2 + cmd = 96 + 2 return _gen_command_with_mid(cmd, mid, proto_ver, reason_code, properties) + def gen_pubcomp(mid, proto_ver=4, reason_code=-1, properties=None): return _gen_command_with_mid(112, mid, proto_ver, reason_code, properties) @@ -272,54 +296,56 @@ def gen_subscribe(mid, topic, qos, cmd=130, proto_ver=4, properties=b""): packet = struct.pack("!B", cmd) if proto_ver == 5: if properties == b"": - packet += pack_remaining_length(2+1+2+len(topic)+1) - pack_format = "!HBH"+str(len(topic))+"sB" + packet += pack_remaining_length(2 + 1 + 2 + len(topic) + 1) + pack_format = "!HBH" + str(len(topic)) + "sB" return packet + struct.pack(pack_format, mid, 0, len(topic), topic, qos) else: properties = mqtt5_props.prop_finalise(properties) - packet += pack_remaining_length(2+1+2+len(topic)+len(properties)) - pack_format = "!H"+str(len(properties))+"s"+"H"+str(len(topic))+"sB" + packet += pack_remaining_length(2 + 1 + 2 + len(topic) + len(properties)) + pack_format = "!H" + str(len(properties)) + "s" + "H" + str(len(topic)) + "sB" return packet + struct.pack(pack_format, mid, properties, len(topic), topic, qos) else: - packet += pack_remaining_length(2+2+len(topic)+1) - pack_format = "!HH"+str(len(topic))+"sB" + packet += pack_remaining_length(2 + 2 + len(topic) + 1) + pack_format = "!HH" + str(len(topic)) + "sB" return packet + struct.pack(pack_format, mid, len(topic), topic, qos) def gen_suback(mid, qos, proto_ver=4): if proto_ver == 5: - return struct.pack('!BBHBB', 144, 2+1+1, mid, 0, qos) + return struct.pack("!BBHBB", 144, 2 + 1 + 1, mid, 0, qos) else: - return struct.pack('!BBHB', 144, 2+1, mid, qos) + return struct.pack("!BBHB", 144, 2 + 1, mid, qos) + def gen_unsubscribe(mid, topic, cmd=162, proto_ver=4, properties=b""): topic = topic.encode("utf-8") if proto_ver == 5: if properties == b"": - pack_format = "!BBHBH"+str(len(topic))+"s" - return struct.pack(pack_format, cmd, 2+2+len(topic)+1, mid, 0, len(topic), topic) + pack_format = "!BBHBH" + str(len(topic)) + "s" + return struct.pack(pack_format, cmd, 2 + 2 + len(topic) + 1, mid, 0, len(topic), topic) else: properties = mqtt5_props.prop_finalise(properties) packet = struct.pack("!B", cmd) - l = 2+2+len(topic)+1+len(properties) # noqa: E741 + l = 2 + 2 + len(topic) + 1 + len(properties) # noqa: E741 packet += pack_remaining_length(l) - pack_format = "!HB"+str(len(properties))+"sH"+str(len(topic))+"s" + pack_format = "!HB" + str(len(properties)) + "sH" + str(len(topic)) + "s" packet += struct.pack(pack_format, mid, len(properties), properties, len(topic), topic) return packet else: - pack_format = "!BBHH"+str(len(topic))+"s" - return struct.pack(pack_format, cmd, 2+2+len(topic), mid, len(topic), topic) + pack_format = "!BBHH" + str(len(topic)) + "s" + return struct.pack(pack_format, cmd, 2 + 2 + len(topic), mid, len(topic), topic) + def gen_unsubscribe_multiple(mid, topics, proto_ver=4): packet = b"" remaining_length = 0 for t in topics: t = t.encode("utf-8") - remaining_length += 2+len(t) - packet += struct.pack("!H"+str(len(t))+"s", len(t), t) + remaining_length += 2 + len(t) + packet += struct.pack("!H" + str(len(t)) + "s", len(t), t) if proto_ver == 5: - remaining_length += 2+1 + remaining_length += 2 + 1 return struct.pack("!BBHB", 162, remaining_length, mid, 0) + packet else: @@ -327,44 +353,49 @@ def gen_unsubscribe_multiple(mid, topics, proto_ver=4): return struct.pack("!BBH", 162, remaining_length, mid) + packet + def gen_unsuback(mid, reason_code=0, proto_ver=4): if proto_ver == 5: if isinstance(reason_code, list): reason_code_count = len(reason_code) - p = struct.pack('!BBHB', 176, 3+reason_code_count, mid, 0) + p = struct.pack("!BBHB", 176, 3 + reason_code_count, mid, 0) for r in reason_code: - p += struct.pack('B', r) + p += struct.pack("B", r) return p else: - return struct.pack('!BBHBB', 176, 4, mid, 0, reason_code) + return struct.pack("!BBHBB", 176, 4, mid, 0, reason_code) else: - return struct.pack('!BBH', 176, 2, mid) + return struct.pack("!BBH", 176, 2, mid) + def gen_pingreq(): - return struct.pack('!BB', 192, 0) + return struct.pack("!BB", 192, 0) + def gen_pingresp(): - return struct.pack('!BB', 208, 0) + return struct.pack("!BB", 208, 0) def _gen_short(cmd, reason_code=-1, proto_ver=5, properties=None): if proto_ver == 5 and (reason_code != -1 or properties is not None): if reason_code == -1: - reason_code = 0 + reason_code = 0 if properties is None: - return struct.pack('!BBB', cmd, 1, reason_code) + return struct.pack("!BBB", cmd, 1, reason_code) elif properties == "": - return struct.pack('!BBBB', cmd, 2, reason_code, 0) + return struct.pack("!BBBB", cmd, 2, reason_code, 0) else: properties = mqtt5_props.prop_finalise(properties) - return struct.pack("!BBB", cmd, 1+len(properties), reason_code) + properties + return struct.pack("!BBB", cmd, 1 + len(properties), reason_code) + properties else: - return struct.pack('!BB', cmd, 0) + return struct.pack("!BB", cmd, 0) + def gen_disconnect(reason_code=-1, proto_ver=4, properties=None): return _gen_short(0xE0, reason_code, proto_ver, properties) + def gen_auth(reason_code=-1, properties=None): return _gen_short(0xF0, reason_code, 5, properties) @@ -421,4 +452,4 @@ def get_test_server_port() -> int: """ Get the port number for the test server. """ - return int(os.environ['PAHO_SERVER_PORT']) + return int(os.environ["PAHO_SERVER_PORT"]) diff --git a/tests/test_client.py b/tests/test_client.py index adfe6014..b3f0b1d4 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -10,18 +10,20 @@ from tests.testsupport.broker import fake_broker # noqa: F401 -@pytest.mark.parametrize("proto_ver", [ - (client.MQTTv31), - (client.MQTTv311), -]) +@pytest.mark.parametrize( + "proto_ver", + [ + (client.MQTTv31), + (client.MQTTv311), + ], +) class Test_connect: """ Tests on connect/disconnect behaviour of the client """ def test_01_con_discon_success(self, proto_ver, fake_broker): - mqttc = client.Client( - "01-con-discon-success", protocol=proto_ver) + mqttc = client.Client("01-con-discon-success", protocol=proto_ver) def on_connect(mqttc, obj, flags, rc): assert rc == 0 @@ -35,9 +37,7 @@ def on_connect(mqttc, obj, flags, rc): try: fake_broker.start() - connect_packet = paho_test.gen_connect( - "01-con-discon-success", keepalive=60, - proto_ver=proto_ver) + connect_packet = paho_test.gen_connect("01-con-discon-success", keepalive=60, proto_ver=proto_ver) packet_in = fake_broker.receive_packet(1000) assert packet_in # Check connection was not closed assert packet_in == connect_packet @@ -59,8 +59,7 @@ def on_connect(mqttc, obj, flags, rc): assert not packet_in # Check connection is closed def test_01_con_failure_rc(self, proto_ver, fake_broker): - mqttc = client.Client( - "01-con-failure-rc", protocol=proto_ver) + mqttc = client.Client("01-con-failure-rc", protocol=proto_ver) def on_connect(mqttc, obj, flags, rc): assert rc == 1 @@ -73,9 +72,7 @@ def on_connect(mqttc, obj, flags, rc): try: fake_broker.start() - connect_packet = paho_test.gen_connect( - "01-con-failure-rc", keepalive=60, - proto_ver=proto_ver) + connect_packet = paho_test.gen_connect("01-con-failure-rc", keepalive=60, proto_ver=proto_ver) packet_in = fake_broker.receive_packet(1000) assert packet_in # Check connection was not closed assert packet_in == connect_packet @@ -93,7 +90,6 @@ def on_connect(mqttc, obj, flags, rc): class TestPublishBroker2Client: - def test_invalid_utf8_topic(self, fake_broker): mqttc = client.Client("client-id") @@ -140,7 +136,7 @@ def test_valid_utf8_topic_recv(self, fake_broker): mqttc = client.Client("client-id") # It should be non-ascii multi-bytes character - topic = unicodedata.lookup('SNOWMAN') + topic = unicodedata.lookup("SNOWMAN") def on_message(client, userdata, msg): assert msg.topic == topic @@ -164,9 +160,7 @@ def on_message(client, userdata, msg): assert count # Check connection was not closed assert count == len(connack_packet) - publish_packet = paho_test.gen_publish( - topic.encode('utf-8'), qos=0 - ) + publish_packet = paho_test.gen_publish(topic.encode("utf-8"), qos=0) count = fake_broker.send_packet(publish_packet) assert count # Check connection was not closed assert count == len(publish_packet) @@ -186,7 +180,7 @@ def test_valid_utf8_topic_publish(self, fake_broker): mqttc = client.Client("client-id") # It should be non-ascii multi-bytes character - topic = unicodedata.lookup('SNOWMAN') + topic = unicodedata.lookup("SNOWMAN") mqttc.connect_async("localhost", fake_broker.port) mqttc.loop_start() @@ -208,9 +202,7 @@ def test_valid_utf8_topic_publish(self, fake_broker): # Small sleep needed to avoid connection reset. time.sleep(0.3) - publish_packet = paho_test.gen_publish( - topic.encode('utf-8'), qos=0 - ) + publish_packet = paho_test.gen_publish(topic.encode("utf-8"), qos=0) packet_in = fake_broker.receive_packet(len(publish_packet)) assert packet_in # Check connection was not closed assert packet_in == publish_packet @@ -231,27 +223,27 @@ def test_valid_utf8_topic_publish(self, fake_broker): def test_message_callback(self, fake_broker): mqttc = client.Client("client-id") userdata = { - 'on_message': 0, - 'callback1': 0, - 'callback2': 0, + "on_message": 0, + "callback1": 0, + "callback2": 0, } mqttc.user_data_set(userdata) def on_message(client, userdata, msg): - assert msg.topic == 'topic/value' - userdata['on_message'] += 1 + assert msg.topic == "topic/value" + userdata["on_message"] += 1 def callback1(client, userdata, msg): - assert msg.topic == 'topic/callback/1' - userdata['callback1'] += 1 + assert msg.topic == "topic/callback/1" + userdata["callback1"] += 1 def callback2(client, userdata, msg): - assert msg.topic in ('topic/callback/3', 'topic/callback/1') - userdata['callback2'] += 1 + assert msg.topic in ("topic/callback/3", "topic/callback/1") + userdata["callback2"] += 1 mqttc.on_message = on_message - mqttc.message_callback_add('topic/callback/1', callback1) - mqttc.message_callback_add('topic/callback/+', callback2) + mqttc.message_callback_add("topic/callback/1", callback1) + mqttc.message_callback_add("topic/callback/+", callback2) mqttc.connect_async("localhost", fake_broker.port) mqttc.loop_start() @@ -284,7 +276,6 @@ def callback2(client, userdata, msg): assert count # Check connection was not closed assert count == len(publish_packet) - puback_packet = paho_test.gen_puback(mid=1) packet_in = fake_broker.receive_packet(len(puback_packet)) assert packet_in # Check connection was not closed @@ -313,6 +304,6 @@ def callback2(client, userdata, msg): packet_in = fake_broker.receive_packet(1) assert not packet_in # Check connection is closed - assert userdata['on_message'] == 1 - assert userdata['callback1'] == 1 - assert userdata['callback2'] == 2 + assert userdata["on_message"] == 1 + assert userdata["callback1"] == 1 + assert userdata["callback2"] == 2 diff --git a/tests/test_matcher.py b/tests/test_matcher.py index e2dc02a4..d8145229 100644 --- a/tests/test_matcher.py +++ b/tests/test_matcher.py @@ -7,30 +7,35 @@ class Test_client_function: Tests on topic_matches_sub function in the client module """ - @pytest.mark.parametrize("sub,topic", [ - ("foo/bar", "foo/bar"), - ("foo/+", "foo/bar"), - ("foo/+/baz", "foo/bar/baz"), - ("foo/+/#", "foo/bar/baz"), - ("A/B/+/#", "A/B/B/C"), - ("#", "foo/bar/baz"), - ("#", "/foo/bar"), - ("/#", "/foo/bar"), - ("$SYS/bar", "$SYS/bar"), - ]) + @pytest.mark.parametrize( + "sub,topic", + [ + ("foo/bar", "foo/bar"), + ("foo/+", "foo/bar"), + ("foo/+/baz", "foo/bar/baz"), + ("foo/+/#", "foo/bar/baz"), + ("A/B/+/#", "A/B/B/C"), + ("#", "foo/bar/baz"), + ("#", "/foo/bar"), + ("/#", "/foo/bar"), + ("$SYS/bar", "$SYS/bar"), + ], + ) def test_matching(self, sub, topic): assert client.topic_matches_sub(sub, topic) - - @pytest.mark.parametrize("sub,topic", [ - ("test/6/#", "test/3"), - ("foo/bar", "foo"), - ("foo/+", "foo/bar/baz"), - ("foo/+/baz", "foo/bar/bar"), - ("foo/+/#", "fo2/bar/baz"), - ("/#", "foo/bar"), - ("#", "$SYS/bar"), - ("$BOB/bar", "$SYS/bar"), - ]) + @pytest.mark.parametrize( + "sub,topic", + [ + ("test/6/#", "test/3"), + ("foo/bar", "foo"), + ("foo/+", "foo/bar/baz"), + ("foo/+/baz", "foo/bar/bar"), + ("foo/+/#", "fo2/bar/baz"), + ("/#", "foo/bar"), + ("#", "$SYS/bar"), + ("$BOB/bar", "$SYS/bar"), + ], + ) def test_not_matching(self, sub, topic): assert not client.topic_matches_sub(sub, topic) diff --git a/tests/test_mqttv5.py b/tests/test_mqttv5.py index 9fd61f2d..dbe996dc 100644 --- a/tests/test_mqttv5.py +++ b/tests/test_mqttv5.py @@ -31,7 +31,6 @@ class Callbacks: - def __init__(self): self.messages = [] self.publisheds = [] @@ -42,16 +41,13 @@ def __init__(self): self.conn_failures = [] def __str__(self): - return str(self.messages) + str(self.messagedicts) + str(self.publisheds) + \ - str(self.subscribeds) + \ - str(self.unsubscribeds) + str(self.disconnects) + return str(self.messages) + str(self.messagedicts) + str(self.publisheds) + str(self.subscribeds) + str(self.unsubscribeds) + str(self.disconnects) def clear(self): self.__init__() def on_connect(self, client, userdata, flags, reasonCode, properties): - self.connecteds.append({"userdata": userdata, "flags": flags, - "reasonCode": reasonCode, "properties": properties}) + self.connecteds.append({"userdata": userdata, "flags": flags, "reasonCode": reasonCode, "properties": properties}) def on_connect_fail(self, client, userdata): self.conn_failures.append({"userdata": userdata}) @@ -71,8 +67,7 @@ def wait_connected(self): return self.wait(self.connecteds) def on_disconnect(self, client, userdata, reasonCode, properties=None): - self.disconnecteds.append( - {"reasonCode": reasonCode, "properties": properties}) + self.disconnecteds.append({"reasonCode": reasonCode, "properties": properties}) def wait_disconnected(self): return self.wait(self.disconnecteds) @@ -87,15 +82,13 @@ def wait_published(self): return self.wait(self.publisheds) def on_subscribe(self, client, userdata, mid, reasonCodes, properties): - self.subscribeds.append({"mid": mid, "userdata": userdata, - "properties": properties, "reasonCodes": reasonCodes}) + self.subscribeds.append({"mid": mid, "userdata": userdata, "properties": properties, "reasonCodes": reasonCodes}) def wait_subscribed(self): return self.wait(self.subscribeds) def unsubscribed(self, client, userdata, mid, properties, reasonCodes): - self.unsubscribeds.append({"mid": mid, "userdata": userdata, - "properties": properties, "reasonCodes": reasonCodes}) + self.unsubscribeds.append({"mid": mid, "userdata": userdata, "properties": properties, "reasonCodes": reasonCodes}) def wait_unsubscribed(self): return self.wait(self.unsubscribeds) @@ -116,8 +109,7 @@ def register(self, client): def cleanRetained(port): callback = Callbacks() - curclient = paho.mqtt.client.Client(b"clean retained", - protocol=paho.mqtt.client.MQTTv5) + curclient = paho.mqtt.client.Client(b"clean retained", protocol=paho.mqtt.client.MQTTv5) curclient.loop_start() callback.register(curclient) curclient.connect(host="localhost", port=port) @@ -130,7 +122,7 @@ def cleanRetained(port): curclient.publish(message["message"].topic, b"", 0, retain=True) curclient.disconnect() curclient.loop_stop() - time.sleep(.1) + time.sleep(0.1) def cleanup(port): @@ -139,13 +131,12 @@ def cleanup(port): clientids = ("aclient", "bclient") for clientid in clientids: - curclient = paho.mqtt.client.Client(clientid.encode( - "utf-8"), protocol=paho.mqtt.client.MQTTv5) + curclient = paho.mqtt.client.Client(clientid.encode("utf-8"), protocol=paho.mqtt.client.MQTTv5) curclient.loop_start() curclient.connect(host="localhost", port=port, clean_start=True) - time.sleep(.1) + time.sleep(0.1) curclient.disconnect() - time.sleep(.1) + time.sleep(0.1) curclient.loop_stop() # clean retained messages @@ -154,7 +145,6 @@ def cleanup(port): class Test(unittest.TestCase): - @classmethod def setUpClass(cls): global callback, callback2, aclient, bclient @@ -187,8 +177,8 @@ def setUpClass(cls): callback = Callbacks() callback2 = Callbacks() - #aclient = mqtt_client.Client(b"\xEF\xBB\xBF" + "myclientid".encode("utf-8")) - #aclient = mqtt_client.Client("myclientid".encode("utf-8")) + # aclient = mqtt_client.Client(b"\xEF\xBB\xBF" + "myclientid".encode("utf-8")) + # aclient = mqtt_client.Client("myclientid".encode("utf-8")) aclient = paho.mqtt.client.Client(b"aclient", protocol=paho.mqtt.client.MQTTv5) callback.register(aclient) @@ -199,13 +189,14 @@ def setUpClass(cls): def tearDownClass(cls): # Another hack to stop the test broker... we rely on fact that it use a sockserver.TCPServer import mqtt.brokers + mqtt.brokers.listeners.TCPListeners.server.shutdown() cls._test_broker.join(5) def waitfor(self, queue, depth, limit): total = 0 while len(queue) < depth and total < limit: - interval = .5 + interval = 0.5 total += interval time.sleep(interval) @@ -224,7 +215,7 @@ def test_basic(self): aclient.publish(topics[0], b"qos 2", 2) i = 0 while len(callback.messages) < 3 and i < 10: - time.sleep(.2) + time.sleep(0.2) i += 1 self.assertEqual(len(callback.messages), 3) aclient.disconnect() @@ -244,7 +235,6 @@ def test_connect_fail(self): fclient.loop_stop() def test_retained_message(self): - publish_properties = Properties(PacketTypes.PUBLISH) publish_properties.UserProperty = ("a", "2") publish_properties.UserProperty = ("c", "3") @@ -254,12 +244,9 @@ def test_retained_message(self): aclient.connect(host="localhost", port=self._test_broker_port) aclient.loop_start() response = callback.wait_connected() - aclient.publish(topics[1], b"qos 0", 0, - retain=True, properties=publish_properties) - aclient.publish(topics[2], b"qos 1", 1, - retain=True, properties=publish_properties) - aclient.publish(topics[3], b"qos 2", 2, - retain=True, properties=publish_properties) + aclient.publish(topics[1], b"qos 0", 0, retain=True, properties=publish_properties) + aclient.publish(topics[2], b"qos 1", 1, retain=True, properties=publish_properties) + aclient.publish(topics[3], b"qos 2", 2, retain=True, properties=publish_properties) # wait until those messages are published time.sleep(1) aclient.subscribe(wildtopics[5], options=SubscribeOptions(qos=2)) @@ -272,14 +259,11 @@ def test_retained_message(self): self.assertEqual(len(callback.messages), 3) userprops = callback.messages[0]["message"].properties.UserProperty - self.assertTrue(userprops in [[("a", "2"), ("c", "3")], [ - ("c", "3"), ("a", "2")]], userprops) + self.assertTrue(userprops in [[("a", "2"), ("c", "3")], [("c", "3"), ("a", "2")]], userprops) userprops = callback.messages[1]["message"].properties.UserProperty - self.assertTrue(userprops in [[("a", "2"), ("c", "3")], [ - ("c", "3"), ("a", "2")]], userprops) + self.assertTrue(userprops in [[("a", "2"), ("c", "3")], [("c", "3"), ("a", "2")]], userprops) userprops = callback.messages[2]["message"].properties.UserProperty - self.assertTrue(userprops in [[("a", "2"), ("c", "3")], [ - ("c", "3"), ("a", "2")]], userprops) + self.assertTrue(userprops in [[("a", "2"), ("c", "3")], [("c", "3"), ("a", "2")]], userprops) qoss = [callback.messages[i]["message"].qos for i in range(3)] self.assertTrue(1 in qoss and 2 in qoss and 0 in qoss, qoss) @@ -296,8 +280,7 @@ def test_will_message(self): will_properties.UserProperty = ("a", "2") will_properties.UserProperty = ("c", "3") - aclient.will_set(topics[2], payload=b"will message", - properties=will_properties) + aclient.will_set(topics[2], payload=b"will message", properties=will_properties) aclient.connect(host="localhost", port=self._test_broker_port, keepalive=2) aclient.loop_start() @@ -331,8 +314,7 @@ def test_zero_length_clientid(self): client0.connect(host="localhost", port=self._test_broker_port, clean_start=False) response = callback0.wait_connected() self.assertEqual(response["reasonCode"].getName(), "Success") - self.assertTrue( - len(response["properties"].AssignedClientIdentifier) > 0) + self.assertTrue(len(response["properties"].AssignedClientIdentifier) > 0) client0.disconnect() client0.loop_stop() @@ -342,21 +324,18 @@ def test_zero_length_clientid(self): client0.connect(host="localhost", port=self._test_broker_port) # should work response = callback0.wait_connected() self.assertEqual(response["reasonCode"].getName(), "Success") - self.assertTrue( - len(response["properties"].AssignedClientIdentifier) > 0) + self.assertTrue(len(response["properties"].AssignedClientIdentifier) > 0) client0.disconnect() client0.loop_stop() # when we supply a client id, we should not get one assigned - client0 = paho.mqtt.client.Client( - "client0", protocol=paho.mqtt.client.MQTTv5) + client0 = paho.mqtt.client.Client("client0", protocol=paho.mqtt.client.MQTTv5) callback0.register(client0) client0.loop_start() client0.connect(host="localhost", port=self._test_broker_port) # should work response = callback0.wait_connected() self.assertEqual(response["reasonCode"].getName(), "Success") - self.assertFalse( - hasattr(response["properties"], "AssignedClientIdentifier")) + self.assertFalse(hasattr(response["properties"], "AssignedClientIdentifier")) client0.disconnect() client0.loop_stop() @@ -366,8 +345,7 @@ def test_offline_message_queueing(self): ocallback = Callbacks() clientid = b"offline message queueing" - oclient = paho.mqtt.client.Client( - clientid, protocol=paho.mqtt.client.MQTTv5) + oclient = paho.mqtt.client.Client(clientid, protocol=paho.mqtt.client.MQTTv5) ocallback.register(oclient) connect_properties = Properties(PacketTypes.CONNECT) connect_properties.SessionExpiryInterval = 99999 @@ -389,8 +367,7 @@ def test_offline_message_queueing(self): bclient.disconnect() bclient.loop_stop() - oclient = paho.mqtt.client.Client( - clientid, protocol=paho.mqtt.client.MQTTv5) + oclient = paho.mqtt.client.Client(clientid, protocol=paho.mqtt.client.MQTTv5) ocallback.register(oclient) oclient.loop_start() oclient.connect(host="localhost", port=self._test_broker_port, clean_start=False) @@ -399,10 +376,8 @@ def test_offline_message_queueing(self): oclient.disconnect() oclient.loop_stop() - self.assertTrue(len(ocallback.messages) in [ - 2, 3], len(ocallback.messages)) - logging.info("This server %s queueing QoS 0 messages for offline clients" % - ("is" if len(ocallback.messages) == 3 else "is not")) + self.assertTrue(len(ocallback.messages) in [2, 3], len(ocallback.messages)) + logging.info("This server %s queueing QoS 0 messages for offline clients" % ("is" if len(ocallback.messages) == 3 else "is not")) def test_overlapping_subscriptions(self): # overlapping subscriptions. When there is more than one matching subscription for the same client for a topic, @@ -411,30 +386,28 @@ def test_overlapping_subscriptions(self): ocallback = Callbacks() clientid = b"overlapping subscriptions" - oclient = paho.mqtt.client.Client( - clientid, protocol=paho.mqtt.client.MQTTv5) + oclient = paho.mqtt.client.Client(clientid, protocol=paho.mqtt.client.MQTTv5) ocallback.register(oclient) oclient.loop_start() oclient.connect(host="localhost", port=self._test_broker_port) ocallback.wait_connected() - oclient.subscribe([(wildtopics[6], SubscribeOptions(qos=2)), - (wildtopics[0], SubscribeOptions(qos=1))]) + oclient.subscribe([(wildtopics[6], SubscribeOptions(qos=2)), (wildtopics[0], SubscribeOptions(qos=1))]) ocallback.wait_subscribed() oclient.publish(topics[3], b"overlapping topic filters", 2) ocallback.wait_published() time.sleep(1) self.assertTrue(len(ocallback.messages) in [1, 2], ocallback.messages) if len(ocallback.messages) == 1: - logging.info( - "This server is publishing one message for all matching overlapping subscriptions, not one for each.") - self.assertEqual( - ocallback.messages[0]["message"].qos, 2, ocallback.messages[0]["message"].qos) + logging.info("This server is publishing one message for all matching overlapping subscriptions, not one for each.") + self.assertEqual(ocallback.messages[0]["message"].qos, 2, ocallback.messages[0]["message"].qos) else: - logging.info( - "This server is publishing one message per each matching overlapping subscription.") - self.assertTrue((ocallback.messages[0]["message"].qos == 2 and ocallback.messages[1]["message"].qos == 1) or - (ocallback.messages[0]["message"].qos == 1 and ocallback.messages[1]["message"].qos == 2), callback.messages) + logging.info("This server is publishing one message per each matching overlapping subscription.") + self.assertTrue( + (ocallback.messages[0]["message"].qos == 2 and ocallback.messages[1]["message"].qos == 1) + or (ocallback.messages[0]["message"].qos == 1 and ocallback.messages[1]["message"].qos == 2), + callback.messages, + ) oclient.disconnect() oclient.loop_stop() ocallback.clear() @@ -446,8 +419,7 @@ def test_subscribe_failure(self): ocallback = Callbacks() clientid = b"subscribe failure" - oclient = paho.mqtt.client.Client( - clientid, protocol=paho.mqtt.client.MQTTv5) + oclient = paho.mqtt.client.Client(clientid, protocol=paho.mqtt.client.MQTTv5) ocallback.register(oclient) oclient.loop_start() oclient.connect(host="localhost", port=self._test_broker_port) @@ -455,8 +427,7 @@ def test_subscribe_failure(self): oclient.subscribe(nosubscribe_topics[0], qos=2) response = ocallback.wait_subscribed() - self.assertEqual(response["reasonCodes"][0].getName(), "Unspecified error", - f"return code should be 0x80 {response['reasonCodes'][0].getName()}") + self.assertEqual(response["reasonCodes"][0].getName(), "Unspecified error", f"return code should be 0x80 {response['reasonCodes'][0].getName()}") oclient.disconnect() oclient.loop_stop() @@ -493,8 +464,7 @@ def test_unsubscribe(self): def new_client(self, clientid): callback = Callbacks() - client = paho.mqtt.client.Client(clientid.encode( - "utf-8"), protocol=paho.mqtt.client.MQTTv5) + client = paho.mqtt.client.Client(clientid.encode("utf-8"), protocol=paho.mqtt.client.MQTTv5) callback.register(client) client.loop_start() return client, callback @@ -522,8 +492,7 @@ def test_session_expiry(self): fclient, fcallback = self.new_client(clientid) # session should immediately expire - fclient.connect_async(host="localhost", port=self._test_broker_port, clean_start=False, - properties=connect_properties) + fclient.connect_async(host="localhost", port=self._test_broker_port, clean_start=False, properties=connect_properties) connack = fcallback.wait_connected() self.assertEqual(connack["reasonCode"].getName(), "Success") self.assertEqual(connack["flags"]["session present"], False) @@ -547,8 +516,7 @@ def test_session_expiry(self): time.sleep(2) # session should still exist fclient, fcallback = self.new_client(clientid) - fclient.connect(host="localhost", port=self._test_broker_port, clean_start=False, - properties=connect_properties) + fclient.connect(host="localhost", port=self._test_broker_port, clean_start=False, properties=connect_properties) connack = fcallback.wait_connected() self.assertEqual(connack["reasonCode"].getName(), "Success") self.assertEqual(connack["flags"]["session present"], True) @@ -559,8 +527,7 @@ def test_session_expiry(self): time.sleep(6) # session should not exist fclient, fcallback = self.new_client(clientid) - fclient.connect(host="localhost", port=self._test_broker_port, clean_start=False, - properties=connect_properties) + fclient.connect(host="localhost", port=self._test_broker_port, clean_start=False, properties=connect_properties) connack = fcallback.wait_connected() self.assertEqual(connack["reasonCode"].getName(), "Success") self.assertEqual(connack["flags"]["session present"], False) @@ -570,8 +537,7 @@ def test_session_expiry(self): eclient, ecallback = self.new_client(clientid) connect_properties.SessionExpiryInterval = 1 - connack = eclient.connect( - host="localhost", port=self._test_broker_port, properties=connect_properties) + connack = eclient.connect(host="localhost", port=self._test_broker_port, properties=connect_properties) connack = ecallback.wait_connected() self.assertEqual(connack["reasonCode"].getName(), "Success") self.assertEqual(connack["flags"]["session present"], False) @@ -586,8 +552,7 @@ def test_session_expiry(self): time.sleep(3) # session should still exist as we changed the expiry interval on disconnect fclient, fcallback = self.new_client(clientid) - fclient.connect(host="localhost", port=self._test_broker_port, clean_start=False, - properties=connect_properties) + fclient.connect(host="localhost", port=self._test_broker_port, clean_start=False, properties=connect_properties) connack = fcallback.wait_connected() self.assertEqual(connack["reasonCode"].getName(), "Success") self.assertEqual(connack["flags"]["session present"], True) @@ -598,8 +563,7 @@ def test_session_expiry(self): # session should immediately expire fclient, fcallback = self.new_client(clientid) - fclient.connect(host="localhost", port=self._test_broker_port, clean_start=False, - properties=connect_properties) + fclient.connect(host="localhost", port=self._test_broker_port, clean_start=False, properties=connect_properties) connack = fcallback.wait_connected() self.assertEqual(connack["reasonCode"].getName(), "Success") self.assertEqual(connack["flags"]["session present"], False) @@ -623,29 +587,23 @@ def test_user_properties(self): publish_properties = Properties(PacketTypes.PUBLISH) publish_properties.UserProperty = ("a", "2") publish_properties.UserProperty = ("c", "3") - uclient.publish(topics[0], b"", 0, retain=False, - properties=publish_properties) - uclient.publish(topics[0], b"", 1, retain=False, - properties=publish_properties) - uclient.publish(topics[0], b"", 2, retain=False, - properties=publish_properties) + uclient.publish(topics[0], b"", 0, retain=False, properties=publish_properties) + uclient.publish(topics[0], b"", 1, retain=False, properties=publish_properties) + uclient.publish(topics[0], b"", 2, retain=False, properties=publish_properties) count = 0 while len(ucallback.messages) < 3 and count < 50: - time.sleep(.1) + time.sleep(0.1) count += 1 uclient.disconnect() ucallback.wait_disconnected() uclient.loop_stop() self.assertEqual(len(ucallback.messages), 3, ucallback.messages) userprops = ucallback.messages[0]["message"].properties.UserProperty - self.assertTrue(userprops in [[("a", "2"), ("c", "3")], [ - ("c", "3"), ("a", "2")]], userprops) + self.assertTrue(userprops in [[("a", "2"), ("c", "3")], [("c", "3"), ("a", "2")]], userprops) userprops = ucallback.messages[1]["message"].properties.UserProperty - self.assertTrue(userprops in [[("a", "2"), ("c", "3")], [ - ("c", "3"), ("a", "2")]], userprops) + self.assertTrue(userprops in [[("a", "2"), ("c", "3")], [("c", "3"), ("a", "2")]], userprops) userprops = ucallback.messages[2]["message"].properties.UserProperty - self.assertTrue(userprops in [[("a", "2"), ("c", "3")], [ - ("c", "3"), ("a", "2")]], userprops) + self.assertTrue(userprops in [[("a", "2"), ("c", "3")], [("c", "3"), ("a", "2")]], userprops) qoss = [ucallback.messages[i]["message"].qos for i in range(3)] self.assertTrue(1 in qoss and 2 in qoss and 0 in qoss, qoss) @@ -661,19 +619,16 @@ def test_payload_format(self): publish_properties = Properties(PacketTypes.PUBLISH) publish_properties.PayloadFormatIndicator = 1 publish_properties.ContentType = "My name" - info = pclient.publish( - topics[0], b"qos 0", 0, retain=False, properties=publish_properties) + info = pclient.publish(topics[0], b"qos 0", 0, retain=False, properties=publish_properties) info.wait_for_publish() - info = pclient.publish( - topics[0], b"qos 1", 1, retain=False, properties=publish_properties) + info = pclient.publish(topics[0], b"qos 1", 1, retain=False, properties=publish_properties) info.wait_for_publish() - info = pclient.publish( - topics[0], b"qos 2", 2, retain=False, properties=publish_properties) + info = pclient.publish(topics[0], b"qos 2", 2, retain=False, properties=publish_properties) info.wait_for_publish() count = 0 while len(pcallback.messages) < 3 and count < 50: - time.sleep(.1) + time.sleep(0.1) count += 1 pclient.disconnect() pcallback.wait_disconnected() @@ -682,16 +637,13 @@ def test_payload_format(self): self.assertEqual(len(pcallback.messages), 3, pcallback.messages) props = pcallback.messages[0]["message"].properties self.assertEqual(props.ContentType, "My name", props.ContentType) - self.assertEqual(props.PayloadFormatIndicator, - 1, props.PayloadFormatIndicator) + self.assertEqual(props.PayloadFormatIndicator, 1, props.PayloadFormatIndicator) props = pcallback.messages[1]["message"].properties self.assertEqual(props.ContentType, "My name", props.ContentType) - self.assertEqual(props.PayloadFormatIndicator, - 1, props.PayloadFormatIndicator) + self.assertEqual(props.PayloadFormatIndicator, 1, props.PayloadFormatIndicator) props = pcallback.messages[2]["message"].properties self.assertEqual(props.ContentType, "My name", props.ContentType) - self.assertEqual(props.PayloadFormatIndicator, - 1, props.PayloadFormatIndicator) + self.assertEqual(props.PayloadFormatIndicator, 1, props.PayloadFormatIndicator) qoss = [pcallback.messages[i]["message"].qos for i in range(3)] self.assertTrue(1 in qoss and 2 in qoss and 0 in qoss, qoss) @@ -718,17 +670,13 @@ def test_message_expiry(self): laclient.connect(host="localhost", port=self._test_broker_port) publish_properties = Properties(PacketTypes.PUBLISH) publish_properties.MessageExpiryInterval = 1 - laclient.publish(topics[0], b"qos 1 - expire", 1, - retain=False, properties=publish_properties) - laclient.publish(topics[0], b"qos 2 - expire", 2, - retain=False, properties=publish_properties) + laclient.publish(topics[0], b"qos 1 - expire", 1, retain=False, properties=publish_properties) + laclient.publish(topics[0], b"qos 2 - expire", 2, retain=False, properties=publish_properties) publish_properties = Properties(PacketTypes.PUBLISH) publish_properties.MessageExpiryInterval = 6 - laclient.publish(topics[0], b"qos 1 - don't expire", - 1, retain=False, properties=publish_properties) - laclient.publish(topics[0], b"qos 2 - don't expire", - 2, retain=False, properties=publish_properties) + laclient.publish(topics[0], b"qos 1 - don't expire", 1, retain=False, properties=publish_properties) + laclient.publish(topics[0], b"qos 2 - don't expire", 2, retain=False, properties=publish_properties) time.sleep(3) lbclient, lbcallback = self.new_client(f"{clientid} b") @@ -738,10 +686,8 @@ def test_message_expiry(self): self.waitfor(lbcallback.messages, 1, 3) time.sleep(1) self.assertEqual(len(lbcallback.messages), 2, lbcallback.messages) - self.assertTrue(lbcallback.messages[0]["message"].properties.MessageExpiryInterval < 6, - lbcallback.messages[0]["message"].properties.MessageExpiryInterval) - self.assertTrue(lbcallback.messages[1]["message"].properties.MessageExpiryInterval < 6, - lbcallback.messages[1]["message"].properties.MessageExpiryInterval) + self.assertTrue(lbcallback.messages[0]["message"].properties.MessageExpiryInterval < 6, lbcallback.messages[0]["message"].properties.MessageExpiryInterval) + self.assertTrue(lbcallback.messages[1]["message"].properties.MessageExpiryInterval < 6, lbcallback.messages[1]["message"].properties.MessageExpiryInterval) laclient.disconnect() lacallback.wait_disconnected() laclient.loop_stop() @@ -752,22 +698,20 @@ def test_message_expiry(self): def test_subscribe_options(self): # noLocal - clientid = 'subscribe options - noLocal' + clientid = "subscribe options - noLocal" laclient, lacallback = self.new_client(f"{clientid} a") laclient.connect(host="localhost", port=self._test_broker_port) lacallback.wait_connected() laclient.loop_start() - laclient.subscribe( - topics[0], options=SubscribeOptions(qos=2, noLocal=True)) + laclient.subscribe(topics[0], options=SubscribeOptions(qos=2, noLocal=True)) lacallback.wait_subscribed() lbclient, lbcallback = self.new_client(f"{clientid} b") lbclient.connect(host="localhost", port=self._test_broker_port) lbcallback.wait_connected() lbclient.loop_start() - lbclient.subscribe( - topics[0], options=SubscribeOptions(qos=2, noLocal=True)) + lbclient.subscribe(topics[0], options=SubscribeOptions(qos=2, noLocal=True)) lbcallback.wait_subscribed() laclient.publish(topics[0], b"noLocal test", 1, retain=False) @@ -784,18 +728,15 @@ def test_subscribe_options(self): lbclient.loop_stop() # retainAsPublished - clientid = 'subscribe options - retain as published' + clientid = "subscribe options - retain as published" laclient, lacallback = self.new_client(f"{clientid} a") laclient.connect(host="localhost", port=self._test_broker_port) lacallback.wait_connected() - laclient.subscribe(topics[0], options=SubscribeOptions( - qos=2, retainAsPublished=True)) + laclient.subscribe(topics[0], options=SubscribeOptions(qos=2, retainAsPublished=True)) lacallback.wait_subscribed() self.waitfor(lacallback.subscribeds, 1, 3) - laclient.publish( - topics[0], b"retain as published false", 1, retain=False) - laclient.publish( - topics[0], b"retain as published true", 1, retain=True) + laclient.publish(topics[0], b"retain as published false", 1, retain=False) + laclient.publish(topics[0], b"retain as published true", 1, retain=True) self.waitfor(lacallback.messages, 2, 3) time.sleep(1) @@ -808,7 +749,7 @@ def test_subscribe_options(self): self.assertEqual(lacallback.messages[1]["message"].retain, True) # retainHandling - clientid = 'subscribe options - retain handling' + clientid = "subscribe options - retain handling" laclient, lacallback = self.new_client(f"{clientid} a") laclient.connect(host="localhost", port=self._test_broker_port) lacallback.wait_connected() @@ -818,15 +759,13 @@ def test_subscribe_options(self): time.sleep(1) # retain handling 1 only gives us retained messages on a new subscription - laclient.subscribe( - wildtopics[5], options=SubscribeOptions(2, retainHandling=1)) + laclient.subscribe(wildtopics[5], options=SubscribeOptions(2, retainHandling=1)) lacallback.wait_subscribed() self.assertEqual(len(lacallback.messages), 3) qoss = [lacallback.messages[i]["message"].qos for i in range(3)] self.assertTrue(1 in qoss and 2 in qoss and 0 in qoss, qoss) lacallback.clear() - laclient.subscribe( - wildtopics[5], options=SubscribeOptions(2, retainHandling=1)) + laclient.subscribe(wildtopics[5], options=SubscribeOptions(2, retainHandling=1)) lacallback.wait_subscribed() time.sleep(1) self.assertEqual(len(lacallback.messages), 0) @@ -839,15 +778,13 @@ def test_subscribe_options(self): lacallback.wait_unsubscribed() # check that we really did remove that subscription - laclient.subscribe( - wildtopics[5], options=SubscribeOptions(2, retainHandling=1)) + laclient.subscribe(wildtopics[5], options=SubscribeOptions(2, retainHandling=1)) lacallback.wait_subscribed() self.assertEqual(len(lacallback.messages), 3) qoss = [lacallback.messages[i]["message"].qos for i in range(3)] self.assertTrue(1 in qoss and 2 in qoss and 0 in qoss, qoss) lacallback.clear() - laclient.subscribe( - wildtopics[5], options=SubscribeOptions(2, retainHandling=1)) + laclient.subscribe(wildtopics[5], options=SubscribeOptions(2, retainHandling=1)) lacallback.wait_subscribed() time.sleep(1) self.assertEqual(len(lacallback.messages), 0) @@ -860,12 +797,10 @@ def test_subscribe_options(self): lacallback.wait_unsubscribed() lacallback.clear() - laclient.subscribe( - wildtopics[5], options=SubscribeOptions(2, retainHandling=2)) + laclient.subscribe(wildtopics[5], options=SubscribeOptions(2, retainHandling=2)) lacallback.wait_subscribed() self.assertEqual(len(lacallback.messages), 0) - laclient.subscribe( - wildtopics[5], options=SubscribeOptions(2, retainHandling=2)) + laclient.subscribe(wildtopics[5], options=SubscribeOptions(2, retainHandling=2)) lacallback.wait_subscribed() self.assertEqual(len(lacallback.messages), 0) @@ -873,15 +808,13 @@ def test_subscribe_options(self): laclient.unsubscribe(wildtopics[5]) lacallback.wait_unsubscribed() - laclient.subscribe( - wildtopics[5], options=SubscribeOptions(2, retainHandling=0)) + laclient.subscribe(wildtopics[5], options=SubscribeOptions(2, retainHandling=0)) lacallback.wait_subscribed() self.assertEqual(len(lacallback.messages), 3) qoss = [lacallback.messages[i]["message"].qos for i in range(3)] self.assertTrue(1 in qoss and 2 in qoss and 0 in qoss, qoss) lacallback.clear() - laclient.subscribe( - wildtopics[5], options=SubscribeOptions(2, retainHandling=0)) + laclient.subscribe(wildtopics[5], options=SubscribeOptions(2, retainHandling=0)) time.sleep(1) self.assertEqual(len(lacallback.messages), 3) qoss = [lacallback.messages[i]["message"].qos for i in range(3)] @@ -893,7 +826,7 @@ def test_subscribe_options(self): cleanRetained(self._test_broker_port) def test_subscription_identifiers(self): - clientid = 'subscription identifiers' + clientid = "subscription identifiers" laclient, lacallback = self.new_client(f"{clientid} a") laclient.connect(host="localhost", port=self._test_broker_port) @@ -922,8 +855,9 @@ def test_subscription_identifiers(self): self.waitfor(lacallback.messages, 1, 3) self.assertEqual(len(lacallback.messages), 1, lacallback.messages) - self.assertEqual(lacallback.messages[0]["message"].properties.SubscriptionIdentifier[0], - 456789, lacallback.messages[0]["message"].properties.SubscriptionIdentifier) + self.assertEqual( + lacallback.messages[0]["message"].properties.SubscriptionIdentifier[0], 456789, lacallback.messages[0]["message"].properties.SubscriptionIdentifier + ) laclient.disconnect() lacallback.wait_disconnected() laclient.loop_stop() @@ -931,15 +865,14 @@ def test_subscription_identifiers(self): self.waitfor(lbcallback.messages, 1, 3) self.assertEqual(len(lbcallback.messages), 1, lbcallback.messages) expected_subsids = set([2, 3]) - received_subsids = set( - lbcallback.messages[0]["message"].properties.SubscriptionIdentifier) + received_subsids = set(lbcallback.messages[0]["message"].properties.SubscriptionIdentifier) self.assertEqual(received_subsids, expected_subsids, received_subsids) lbclient.disconnect() lbcallback.wait_disconnected() lbclient.loop_stop() def test_request_response(self): - clientid = 'request response' + clientid = "request response" laclient, lacallback = self.new_client(f"{clientid} a") laclient.connect(host="localhost", port=self._test_broker_port) @@ -951,31 +884,25 @@ def test_request_response(self): lbcallback.wait_connected() lbclient.loop_start() - laclient.subscribe( - topics[0], options=SubscribeOptions(2, noLocal=True)) + laclient.subscribe(topics[0], options=SubscribeOptions(2, noLocal=True)) lacallback.wait_subscribed() - lbclient.subscribe( - topics[0], options=SubscribeOptions(2, noLocal=True)) + lbclient.subscribe(topics[0], options=SubscribeOptions(2, noLocal=True)) lbcallback.wait_subscribed() publish_properties = Properties(PacketTypes.PUBLISH) publish_properties.ResponseTopic = topics[0] publish_properties.CorrelationData = b"334" # client a is the requester - laclient.publish(topics[0], b"request", 1, - properties=publish_properties) + laclient.publish(topics[0], b"request", 1, properties=publish_properties) # client b is the responder self.waitfor(lbcallback.messages, 1, 3) self.assertEqual(len(lbcallback.messages), 1, lbcallback.messages) - self.assertEqual(lbcallback.messages[0]["message"].properties.ResponseTopic, topics[0], - lbcallback.messages[0]["message"].properties) - self.assertEqual(lbcallback.messages[0]["message"].properties.CorrelationData, b"334", - lbcallback.messages[0]["message"].properties) + self.assertEqual(lbcallback.messages[0]["message"].properties.ResponseTopic, topics[0], lbcallback.messages[0]["message"].properties) + self.assertEqual(lbcallback.messages[0]["message"].properties.CorrelationData, b"334", lbcallback.messages[0]["message"].properties) - lbclient.publish(lbcallback.messages[0]["message"].properties.ResponseTopic, b"response", 1, - properties=lbcallback.messages[0]["message"].properties) + lbclient.publish(lbcallback.messages[0]["message"].properties.ResponseTopic, b"response", 1, properties=lbcallback.messages[0]["message"].properties) # client a gets the response self.waitfor(lacallback.messages, 1, 3) @@ -989,7 +916,7 @@ def test_request_response(self): lbclient.loop_stop() def test_client_topic_alias(self): - clientid = 'client topic alias' + clientid = "client topic alias" connect_properties = Properties(PacketTypes.CONNECT) connect_properties.TopicAliasMaximum = 0 # server topic aliases not allowed @@ -1012,13 +939,11 @@ def test_client_topic_alias(self): publish_properties = Properties(PacketTypes.PUBLISH) publish_properties.TopicAlias = 1 - laclient.publish(topics[0], b"topic alias 1", - 1, properties=publish_properties) + laclient.publish(topics[0], b"topic alias 1", 1, properties=publish_properties) self.waitfor(lacallback.messages, 1, 3) self.assertEqual(len(lacallback.messages), 1, lacallback.messages) - laclient.publish("", b"topic alias 2", 1, - properties=publish_properties) + laclient.publish("", b"topic alias 2", 1, properties=publish_properties) self.waitfor(lacallback.messages, 2, 3) self.assertEqual(len(lacallback.messages), 2, lacallback.messages) @@ -1028,8 +953,7 @@ def test_client_topic_alias(self): # check aliases have been deleted laclient, lacallback = self.new_client(f"{clientid} a") - laclient.connect(host="localhost", port=self._test_broker_port, clean_start=False, - properties=connect_properties) + laclient.connect(host="localhost", port=self._test_broker_port, clean_start=False, properties=connect_properties) laclient.publish(topics[0], b"topic alias 3", 1) self.waitfor(lacallback.messages, 1, 3) @@ -1037,15 +961,14 @@ def test_client_topic_alias(self): publish_properties = Properties(PacketTypes.PUBLISH) publish_properties.TopicAlias = 1 - laclient.publish("", b"topic alias 4", 1, - properties=publish_properties) + laclient.publish("", b"topic alias 4", 1, properties=publish_properties) # should get back a disconnect with Topic alias invalid lacallback.wait_disconnected() laclient.loop_stop() def test_server_topic_alias(self): - clientid = 'server topic alias' + clientid = "server topic alias" serverTopicAliasMaximum = 1 # server topic alias allowed connect_properties = Properties(PacketTypes.CONNECT) @@ -1068,19 +991,16 @@ def test_server_topic_alias(self): laclient.loop_stop() # first message should set the topic alias - self.assertTrue(hasattr( - lacallback.messages[0]["message"].properties, "TopicAlias"), lacallback.messages[0]["message"].properties) + self.assertTrue(hasattr(lacallback.messages[0]["message"].properties, "TopicAlias"), lacallback.messages[0]["message"].properties) topicalias = lacallback.messages[0]["message"].properties.TopicAlias self.assertTrue(topicalias > 0) self.assertEqual(lacallback.messages[0]["message"].topic, topics[0]) - self.assertEqual( - lacallback.messages[1]["message"].properties.TopicAlias, topicalias) + self.assertEqual(lacallback.messages[1]["message"].properties.TopicAlias, topicalias) self.assertEqual(lacallback.messages[1]["message"].topic, "") - self.assertEqual( - lacallback.messages[2]["message"].properties.TopicAlias, topicalias) + self.assertEqual(lacallback.messages[2]["message"].properties.TopicAlias, topicalias) self.assertEqual(lacallback.messages[2]["message"].topic, "") serverTopicAliasMaximum = 0 # no server topic alias allowed @@ -1104,12 +1024,9 @@ def test_server_topic_alias(self): laclient.loop_stop() # No topic aliases - self.assertFalse(hasattr( - lacallback.messages[0]["message"].properties, "TopicAlias"), lacallback.messages[0]["message"].properties) - self.assertFalse(hasattr( - lacallback.messages[1]["message"].properties, "TopicAlias"), lacallback.messages[1]["message"].properties) - self.assertFalse(hasattr( - lacallback.messages[2]["message"].properties, "TopicAlias"), lacallback.messages[2]["message"].properties) + self.assertFalse(hasattr(lacallback.messages[0]["message"].properties, "TopicAlias"), lacallback.messages[0]["message"].properties) + self.assertFalse(hasattr(lacallback.messages[1]["message"].properties, "TopicAlias"), lacallback.messages[1]["message"].properties) + self.assertFalse(hasattr(lacallback.messages[2]["message"].properties, "TopicAlias"), lacallback.messages[2]["message"].properties) serverTopicAliasMaximum = 0 # no server topic alias allowed connect_properties = Properties(PacketTypes.CONNECT) @@ -1132,15 +1049,12 @@ def test_server_topic_alias(self): laclient.loop_stop() # No topic aliases - self.assertFalse(hasattr( - lacallback.messages[0]["message"].properties, "TopicAlias"), lacallback.messages[0]["message"].properties) - self.assertFalse(hasattr( - lacallback.messages[1]["message"].properties, "TopicAlias"), lacallback.messages[1]["message"].properties) - self.assertFalse(hasattr( - lacallback.messages[2]["message"].properties, "TopicAlias"), lacallback.messages[2]["message"].properties) + self.assertFalse(hasattr(lacallback.messages[0]["message"].properties, "TopicAlias"), lacallback.messages[0]["message"].properties) + self.assertFalse(hasattr(lacallback.messages[1]["message"].properties, "TopicAlias"), lacallback.messages[1]["message"].properties) + self.assertFalse(hasattr(lacallback.messages[2]["message"].properties, "TopicAlias"), lacallback.messages[2]["message"].properties) def test_maximum_packet_size(self): - clientid = 'maximum packet size' + clientid = "maximum packet size" # 1. server max packet size laclient, lacallback = self.new_client(f"{clientid} a") @@ -1148,20 +1062,18 @@ def test_maximum_packet_size(self): connack = lacallback.wait_connected() laclient.loop_start() - serverMaximumPacketSize = 2**28-1 + serverMaximumPacketSize = 2**28 - 1 if hasattr(connack["properties"], "MaximumPacketSize"): serverMaximumPacketSize = connack["properties"].MaximumPacketSize if serverMaximumPacketSize < 65535: # publish bigger packet than server can accept - payload = b"."*serverMaximumPacketSize + payload = b"." * serverMaximumPacketSize laclient.publish(topics[0], payload, 0) # should get back a disconnect with packet size too big response = lacallback.wait_disconnected() - self.assertEqual(len(lacallback.disconnecteds), - 0, lacallback.disconnecteds) - self.assertEqual(response["reasonCode"].getName(), - "Packet too large", response["reasonCode"].getName()) + self.assertEqual(len(lacallback.disconnecteds), 0, lacallback.disconnecteds) + self.assertEqual(response["reasonCode"].getName(), "Packet too large", response["reasonCode"].getName()) else: laclient.disconnect() lacallback.wait_disconnected() @@ -1177,7 +1089,7 @@ def test_maximum_packet_size(self): connack = lacallback.wait_connected() laclient.loop_start() - serverMaximumPacketSize = 2**28-1 + serverMaximumPacketSize = 2**28 - 1 if hasattr(connack["properties"], "MaximumPacketSize"): serverMaximumPacketSize = connack["properties"].MaximumPacketSize @@ -1185,13 +1097,13 @@ def test_maximum_packet_size(self): response = lacallback.wait_subscribed() # send a small enough packet, should get this one back - payload = b"."*(int(maximumPacketSize/2)) + payload = b"." * (int(maximumPacketSize / 2)) laclient.publish(topics[0], payload, 0) self.waitfor(lacallback.messages, 1, 3) self.assertEqual(len(lacallback.messages), 1, lacallback.messages) # send a packet too big to receive - payload = b"."*maximumPacketSize + payload = b"." * maximumPacketSize laclient.publish(topics[0], payload, 1) self.waitfor(lacallback.messages, 2, 3) self.assertEqual(len(lacallback.messages), 1, lacallback.messages) @@ -1220,7 +1132,7 @@ def test_server_keep_alive(self): def test_will_delay(self): # the will message should be received earlier than the session expiry - clientid = 'will delay' + clientid = "will delay" will_properties = Properties(PacketTypes.WILLMESSAGE) connect_properties = Properties(PacketTypes.CONNECT) @@ -1231,8 +1143,7 @@ def test_will_delay(self): connect_properties.SessionExpiryInterval = 5 laclient, lacallback = self.new_client(f"{clientid} a") - laclient.will_set( - topics[0], payload=b"test_will_delay will message", properties=will_properties) + laclient.will_set(topics[0], payload=b"test_will_delay will message", properties=will_properties) laclient.connect(host="localhost", port=self._test_broker_port, properties=connect_properties) connack = lacallback.wait_connected() self.assertEqual(connack["reasonCode"].getName(), "Success") @@ -1252,19 +1163,18 @@ def test_will_delay(self): laclient.socket().close() start = time.time() while lbcallback.messages == []: - time.sleep(.1) + time.sleep(0.1) duration = time.time() - start self.assertAlmostEqual(duration, 4, delta=1) self.assertEqual(lbcallback.messages[0]["message"].topic, topics[0]) - self.assertEqual( - lbcallback.messages[0]["message"].payload, b"test_will_delay will message") + self.assertEqual(lbcallback.messages[0]["message"].payload, b"test_will_delay will message") lbclient.disconnect() lbcallback.wait_disconnected() lbclient.loop_stop() def test_shared_subscriptions(self): - clientid = 'shared subscriptions' + clientid = "shared subscriptions" shared_sub_topic = f"$share/sharename/{topic_prefix}x" shared_pub_topic = f"{topic_prefix}x" @@ -1277,8 +1187,7 @@ def test_shared_subscriptions(self): self.assertEqual(connack["reasonCode"].getName(), "Success") self.assertEqual(connack["flags"]["session present"], False) - laclient.subscribe( - [(shared_sub_topic, SubscribeOptions(2)), (topics[0], SubscribeOptions(2))]) + laclient.subscribe([(shared_sub_topic, SubscribeOptions(2)), (topics[0], SubscribeOptions(2))]) lacallback.wait_subscribed() lbclient, lbcallback = self.new_client(f"{clientid} b") @@ -1289,8 +1198,7 @@ def test_shared_subscriptions(self): self.assertEqual(connack["reasonCode"].getName(), "Success") self.assertEqual(connack["flags"]["session present"], False) - lbclient.subscribe( - [(shared_sub_topic, SubscribeOptions(2)), (topics[0], 2)]) + lbclient.subscribe([(shared_sub_topic, SubscribeOptions(2)), (topics[0], 2)]) lbcallback.wait_subscribed() lacallback.clear() @@ -1300,8 +1208,8 @@ def test_shared_subscriptions(self): for i in range(count): lbclient.publish(topics[0], f"message {i}", 0) j = 0 - while len(lacallback.messages) + len(lbcallback.messages) < 2*count and j < 20: - time.sleep(.1) + while len(lacallback.messages) + len(lbcallback.messages) < 2 * count and j < 20: + time.sleep(0.1) j += 1 time.sleep(1) self.assertEqual(len(lacallback.messages), count) @@ -1314,12 +1222,11 @@ def test_shared_subscriptions(self): lbclient.publish(shared_pub_topic, f"message {i}", 0) j = 0 while len(lacallback.messages) + len(lbcallback.messages) < count and j < 20: - time.sleep(.1) + time.sleep(0.1) j += 1 time.sleep(1) # Each message should only be received once - self.assertEqual(len(lacallback.messages) + - len(lbcallback.messages), count) + self.assertEqual(len(lacallback.messages) + len(lbcallback.messages), count) laclient.disconnect() lacallback.wait_disconnected() diff --git a/tests/test_websocket_integration.py b/tests/test_websocket_integration.py index 80872f9e..458a241f 100644 --- a/tests/test_websocket_integration.py +++ b/tests/test_websocket_integration.py @@ -14,44 +14,47 @@ @pytest.fixture def init_response_headers(): # "Normal" websocket response from server - response_headers = OrderedDict([ - ("Upgrade", "websocket"), - ("Connection", "Upgrade"), - ("Sec-WebSocket-Accept", "testwebsocketkey"), - ("Sec-WebSocket-Protocol", "chat"), - ]) + response_headers = OrderedDict( + [ + ("Upgrade", "websocket"), + ("Connection", "Upgrade"), + ("Sec-WebSocket-Accept", "testwebsocketkey"), + ("Sec-WebSocket-Protocol", "chat"), + ] + ) return response_headers def get_websocket_response(response_headers): - """ Takes headers and constructs HTTP response + """Takes headers and constructs HTTP response 'HTTP/1.1 101 Switching Protocols' is the headers for the response, as expected in client.py """ - response = "\r\n".join([ - "HTTP/1.1 101 Switching Protocols", - "\r\n".join(f"{i}: {j}" for i, j in response_headers.items()), - "\r\n", - ]).encode("utf8") + response = "\r\n".join( + [ + "HTTP/1.1 101 Switching Protocols", + "\r\n".join(f"{i}: {j}" for i, j in response_headers.items()), + "\r\n", + ] + ).encode("utf8") return response -@pytest.mark.parametrize("proto_ver,proto_name", [ - (client.MQTTv31, "MQIsdp"), - (client.MQTTv311, "MQTT"), -]) +@pytest.mark.parametrize( + "proto_ver,proto_name", + [ + (client.MQTTv31, "MQIsdp"), + (client.MQTTv311, "MQTT"), + ], +) class TestInvalidWebsocketResponse: def test_unexpected_response(self, proto_ver, proto_name, fake_websocket_broker): - """ Server responds with a valid code, but it's not what the client expected """ + """Server responds with a valid code, but it's not what the client expected""" - mqttc = client.Client( - "test_unexpected_response", - protocol=proto_ver, - transport="websockets" - ) + mqttc = client.Client("test_unexpected_response", protocol=proto_ver, transport="websockets") class WebsocketHandler(socketserver.BaseRequestHandler): def handle(_self): @@ -64,15 +67,18 @@ def handle(_self): assert str(exc.value) == "WebSocket handshake error" -@pytest.mark.parametrize("proto_ver,proto_name", [ - (client.MQTTv31, "MQIsdp"), - (client.MQTTv311, "MQTT"), -]) +@pytest.mark.parametrize( + "proto_ver,proto_name", + [ + (client.MQTTv31, "MQIsdp"), + (client.MQTTv311, "MQTT"), + ], +) class TestBadWebsocketHeaders: - """ Testing for basic functionality in checking for headers """ + """Testing for basic functionality in checking for headers""" def _get_basic_handler(self, response_headers): - """ Get a basic BaseRequestHandler which returns the information in + """Get a basic BaseRequestHandler which returns the information in self._response_headers """ @@ -81,21 +87,16 @@ def _get_basic_handler(self, response_headers): class WebsocketHandler(socketserver.BaseRequestHandler): def handle(_self): self.data = _self.request.recv(1024).strip() - print('Received', self.data.decode('utf8')) + print("Received", self.data.decode("utf8")) # Respond with data passed in to serve() _self.request.sendall(response) return WebsocketHandler - def test_no_upgrade(self, proto_ver, proto_name, fake_websocket_broker, - init_response_headers): - """ Server doesn't respond with 'connection: upgrade' """ + def test_no_upgrade(self, proto_ver, proto_name, fake_websocket_broker, init_response_headers): + """Server doesn't respond with 'connection: upgrade'""" - mqttc = client.Client( - "test_no_upgrade", - protocol=proto_ver, - transport="websockets" - ) + mqttc = client.Client("test_no_upgrade", protocol=proto_ver, transport="websockets") init_response_headers["Connection"] = "bad" response = self._get_basic_handler(init_response_headers) @@ -105,15 +106,10 @@ def test_no_upgrade(self, proto_ver, proto_name, fake_websocket_broker, assert str(exc.value) == "WebSocket handshake error, connection not upgraded" - def test_bad_secret_key(self, proto_ver, proto_name, fake_websocket_broker, - init_response_headers): - """ Server doesn't give anything after connection: upgrade """ + def test_bad_secret_key(self, proto_ver, proto_name, fake_websocket_broker, init_response_headers): + """Server doesn't give anything after connection: upgrade""" - mqttc = client.Client( - "test_bad_secret_key", - protocol=proto_ver, - transport="websockets" - ) + mqttc = client.Client("test_bad_secret_key", protocol=proto_ver, transport="websockets") response = self._get_basic_handler(init_response_headers) @@ -123,22 +119,25 @@ def test_bad_secret_key(self, proto_ver, proto_name, fake_websocket_broker, assert str(exc.value) == "WebSocket handshake error, invalid secret key" -@pytest.mark.parametrize("proto_ver,proto_name", [ - (client.MQTTv31, "MQIsdp"), - (client.MQTTv311, "MQTT"), -]) +@pytest.mark.parametrize( + "proto_ver,proto_name", + [ + (client.MQTTv31, "MQIsdp"), + (client.MQTTv311, "MQTT"), + ], +) class TestValidHeaders: - """ Testing for functionality in request/response headers """ + """Testing for functionality in request/response headers""" def _get_callback_handler(self, response_headers, check_request=None): - """ Get a basic BaseRequestHandler which returns the information in + """Get a basic BaseRequestHandler which returns the information in self._response_headers """ class WebsocketHandler(socketserver.BaseRequestHandler): def handle(_self): self.data = _self.request.recv(1024).strip() - print('Received', self.data.decode('utf8')) + print("Received", self.data.decode("utf8")) decoded = self.data.decode("utf8") @@ -162,16 +161,10 @@ def handle(_self): return WebsocketHandler - def test_successful_connection(self, proto_ver, proto_name, - fake_websocket_broker, - init_response_headers): - """ Connect successfully, on correct path """ + def test_successful_connection(self, proto_ver, proto_name, fake_websocket_broker, init_response_headers): + """Connect successfully, on correct path""" - mqttc = client.Client( - "test_successful_connection", - protocol=proto_ver, - transport="websockets" - ) + mqttc = client.Client("test_successful_connection", protocol=proto_ver, transport="websockets") response = self._get_callback_handler(init_response_headers) @@ -180,20 +173,17 @@ def test_successful_connection(self, proto_ver, proto_name, mqttc.disconnect() - @pytest.mark.parametrize("mqtt_path", [ - "/mqtt" - "/special", - None, - ]) - def test_correct_path(self, proto_ver, proto_name, fake_websocket_broker, - mqtt_path, init_response_headers): - """ Make sure it can connect on user specified paths """ - - mqttc = client.Client( - "test_correct_path", - protocol=proto_ver, - transport="websockets" - ) + @pytest.mark.parametrize( + "mqtt_path", + [ + "/mqtt" "/special", + None, + ], + ) + def test_correct_path(self, proto_ver, proto_name, fake_websocket_broker, mqtt_path, init_response_headers): + """Make sure it can connect on user specified paths""" + + mqttc = client.Client("test_correct_path", protocol=proto_ver, transport="websockets") mqttc.ws_set_options( path=mqtt_path, @@ -214,21 +204,19 @@ def check_path_correct(decoded): mqttc.disconnect() - @pytest.mark.parametrize("auth_headers", [ - {"Authorization": "test123"}, - {"Authorization": "test123", "auth2": "abcdef"}, - # Won't be checked, but make sure it still works even if the user passes it - None, - ]) - def test_correct_auth(self, proto_ver, proto_name, fake_websocket_broker, - auth_headers, init_response_headers): - """ Make sure it sends the right auth headers """ - - mqttc = client.Client( - "test_correct_path", - protocol=proto_ver, - transport="websockets" - ) + @pytest.mark.parametrize( + "auth_headers", + [ + {"Authorization": "test123"}, + {"Authorization": "test123", "auth2": "abcdef"}, + # Won't be checked, but make sure it still works even if the user passes it + None, + ], + ) + def test_correct_auth(self, proto_ver, proto_name, fake_websocket_broker, auth_headers, init_response_headers): + """Make sure it sends the right auth headers""" + + mqttc = client.Client("test_correct_path", protocol=proto_ver, transport="websockets") mqttc.ws_set_options( headers=auth_headers, diff --git a/tests/test_websockets.py b/tests/test_websockets.py index f2605a3a..5fc267a9 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -6,10 +6,10 @@ class TestHeaders: - """ Make sure headers are used correctly """ + """Make sure headers are used correctly""" def test_normal_headers(self): - """ Normal headers as specified in RFC 6455 """ + """Normal headers as specified in RFC 6455""" response = [ "HTTP/1.1 101 Switching Protocols", @@ -54,15 +54,18 @@ def fakerecv(*args): # error assert str(exc.value) == "WebSocket handshake error, invalid secret key" - expected_sent = [i.format(**wargs) for i in [ - "GET {path:s} HTTP/1.1", - "Host: {host:s}", - "Upgrade: websocket", - "Connection: Upgrade", - "Sec-Websocket-Protocol: mqtt", - "Sec-Websocket-Version: 13", - "Origin: https://{host:s}:{port:d}", - ]] + expected_sent = [ + i.format(**wargs) + for i in [ + "GET {path:s} HTTP/1.1", + "Host: {host:s}", + "Upgrade: websocket", + "Connection: Upgrade", + "Sec-Websocket-Protocol: mqtt", + "Sec-Websocket-Version: 13", + "Origin: https://{host:s}:{port:d}", + ] + ] # Only sends the header once assert mocksock.send.call_count == 1 diff --git a/tests/testsupport/broker.py b/tests/testsupport/broker.py index 6cf72c7f..4f4cf7a4 100644 --- a/tests/testsupport/broker.py +++ b/tests/testsupport/broker.py @@ -22,7 +22,7 @@ def __init__(self): def start(self): if self._sock is None: - raise ValueError('Socket is not open') + raise ValueError("Socket is not open") (conn, address) = self._sock.accept() conn.settimeout(10) @@ -39,14 +39,14 @@ def finish(self): def receive_packet(self, num_bytes): if self._conn is None: - raise ValueError('Connection is not open') + raise ValueError("Connection is not open") packet_in = self._conn.recv(num_bytes) return packet_in def send_packet(self, packet_out): if self._conn is None: - raise ValueError('Connection is not open') + raise ValueError("Connection is not open") count = self._conn.send(packet_out) return count