From 50b280cac975b639872fcb0b2223189c128836e0 Mon Sep 17 00:00:00 2001 From: Pierre Fersing Date: Tue, 9 Jan 2024 20:58:18 +0100 Subject: [PATCH] Apply suggestions from code review Co-authored-by: Aarni Koskela --- src/paho/mqtt/client.py | 58 +++++++++++++++++++++------------------- src/paho/mqtt/enums.py | 28 +++++++++---------- src/paho/mqtt/publish.py | 5 ++-- 3 files changed, 46 insertions(+), 45 deletions(-) diff --git a/src/paho/mqtt/client.py b/src/paho/mqtt/client.py index 7ee97206..eb4bc0bc 100644 --- a/src/paho/mqtt/client.py +++ b/src/paho/mqtt/client.py @@ -31,11 +31,10 @@ import struct import threading import time -import typing import urllib.parse import urllib.request import uuid -from typing import Any, Callable, Dict, Iterator, List, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Sequence, Tuple, Union from .enums import ConnackCode, ConnectionState, LogLevel, MessageState, MessageType, MQTTErrorCode, MQTTProtocolVersion, PahoClientMode from .matcher import MQTTMatcher @@ -48,7 +47,7 @@ except ImportError: from typing_extensions import Literal # type: ignore -if typing.TYPE_CHECKING: +if TYPE_CHECKING: try: from typing import TypedDict # type: ignore except ImportError: @@ -147,22 +146,22 @@ class _OutPacket(TypedDict): CONNACK_REFUSED_NOT_AUTHORIZED = ConnackCode.CONNACK_REFUSED_NOT_AUTHORIZED # Connection state -mqtt_cs_new = ConnectionState.mqtt_cs_new -mqtt_cs_connected = ConnectionState.mqtt_cs_connected -mqtt_cs_disconnecting = ConnectionState.mqtt_cs_disconnecting -mqtt_cs_connect_async = ConnectionState.mqtt_cs_connect_async +mqtt_cs_new = ConnectionState.MQTT_CS_NEW +mqtt_cs_connected = ConnectionState.MQTT_CS_CONNECTED +mqtt_cs_disconnecting = ConnectionState.MQTT_CS_DISCONNECTING +mqtt_cs_connect_async = ConnectionState.MQTT_CS_CONNECT_ASYNC # Message state -mqtt_ms_invalid = MessageState.mqtt_ms_invalid -mqtt_ms_publish = MessageState.mqtt_ms_publish -mqtt_ms_wait_for_puback = MessageState.mqtt_ms_wait_for_puback -mqtt_ms_wait_for_pubrec = MessageState.mqtt_ms_wait_for_pubrec -mqtt_ms_resend_pubrel = MessageState.mqtt_ms_resend_pubrel -mqtt_ms_wait_for_pubrel = MessageState.mqtt_ms_wait_for_pubrel -mqtt_ms_resend_pubcomp = MessageState.mqtt_ms_resend_pubcomp -mqtt_ms_wait_for_pubcomp = MessageState.mqtt_ms_wait_for_pubcomp -mqtt_ms_send_pubrec = MessageState.mqtt_ms_send_pubrec -mqtt_ms_queued = MessageState.mqtt_ms_queued +mqtt_ms_invalid = MessageState.MQTT_MS_INVALID +mqtt_ms_publish = MessageState.MQTT_MS_PUBLISH +mqtt_ms_wait_for_puback = MessageState.MQTT_MS_WAIT_FOR_PUBACK +mqtt_ms_wait_for_pubrec = MessageState.MQTT_MS_WAIT_FOR_PUBREC +mqtt_ms_resend_pubrel = MessageState.MQTT_MS_RESEND_PUBREL +mqtt_ms_wait_for_pubrel = MessageState.MQTT_MS_WAIT_FOR_PUBREL +mqtt_ms_resend_pubcomp = MessageState.MQTT_MS_RESEND_PUBCOMP +mqtt_ms_wait_for_pubcomp = MessageState.MQTT_MS_WAIT_FOR_PUBCOMP +mqtt_ms_send_pubrec = MessageState.MQTT_MS_SEND_PUBREC +mqtt_ms_queued = MessageState.MQTT_MS_QUEUED MQTT_ERR_AGAIN = MQTTErrorCode.MQTT_ERR_AGAIN MQTT_ERR_SUCCESS = MQTTErrorCode.MQTT_ERR_SUCCESS @@ -636,7 +635,7 @@ def __init__( self._transport = transport.lower() self._protocol = protocol self._userdata = userdata - self._sock: socket.socket | WebsocketWrapper | ssl.SSLSocket | None = None + self._sock: SocketLike | None = None self._sockpairR: socket.socket | None = None self._sockpairW: socket.socket | None = None self._keepalive = 60 @@ -3969,13 +3968,20 @@ def _get_proxy(self) -> dict[str, Any] | None: return None def _create_socket(self) -> SocketLike: - tcp_sock = self._create_socket_connection() - sock = self._create_ssl_socket_if_enable(tcp_sock) + sock = self._create_socket_connection() + if self._ssl: + sock = self._ssl_wrap_socket(sock) if self._transport == "websockets": sock.settimeout(self._keepalive) - return WebsocketWrapper(sock, self._host, self._port, self._ssl, - self._websocket_path, self._websocket_extra_headers) + return WebsocketWrapper( + socket=sock, + host=self._host, + port=self._port, + is_ssl=self._ssl, + path=self._websocket_path, + extra_headers=self._websocket_extra_headers, + ) return sock @@ -3989,10 +3995,7 @@ def _create_socket_connection(self) -> _socket.socket: else: return socket.create_connection(addr, timeout=self._connect_timeout, source_address=source) - def _create_ssl_socket_if_enable(self, tcp_sock: _socket.socket) -> _socket.socket|ssl.SSLSocket: - if not self._ssl: - return tcp_sock - + def _ssl_wrap_socket(self, tcp_sock: _socket.socket) -> ssl.SSLSocket: if self._ssl_context is None: raise ValueError( "Impossible condition. _ssl_context should never be None if _ssl is True" @@ -4017,8 +4020,7 @@ def _create_ssl_socket_if_enable(self, tcp_sock: _socket.socket) -> _socket.sock ) else: # If SSL context has already checked hostname, then don't need to do it again - if (hasattr(self._ssl_context, 'check_hostname') and - self._ssl_context.check_hostname): # type: ignore + if getattr(self._ssl_context, 'check_hostname', False): # type: ignore verify_host = False ssl_sock.settimeout(self._keepalive) diff --git a/src/paho/mqtt/enums.py b/src/paho/mqtt/enums.py index 617b0e2f..33f18cb6 100644 --- a/src/paho/mqtt/enums.py +++ b/src/paho/mqtt/enums.py @@ -64,23 +64,23 @@ class ConnackCode(enum.IntEnum): class ConnectionState(enum.IntEnum): - mqtt_cs_new = 0 - mqtt_cs_connected = 1 - mqtt_cs_disconnecting = 2 - mqtt_cs_connect_async = 3 + MQTT_CS_NEW = 0 + MQTT_CS_CONNECTED = 1 + MQTT_CS_DISCONNECTING = 2 + MQTT_CS_CONNECT_ASYNC = 3 class MessageState(enum.IntEnum): - mqtt_ms_invalid = 0 - mqtt_ms_publish = 1 - mqtt_ms_wait_for_puback = 2 - mqtt_ms_wait_for_pubrec = 3 - mqtt_ms_resend_pubrel = 4 - mqtt_ms_wait_for_pubrel = 5 - mqtt_ms_resend_pubcomp = 6 - mqtt_ms_wait_for_pubcomp = 7 - mqtt_ms_send_pubrec = 8 - mqtt_ms_queued = 9 + MQTT_MS_INVALID = 0 + MQTT_MS_PUBLISH = 1 + MQTT_MS_WAIT_FOR_PUBACK = 2 + MQTT_MS_WAIT_FOR_PUBREC = 3 + MQTT_MS_RESEND_PUBREL = 4 + MQTT_MS_WAIT_FOR_PUBREL = 5 + MQTT_MS_RESEND_PUBCOMP = 6 + MQTT_MS_WAIT_FOR_PUBCOMP = 7 + MQTT_MS_SEND_PUBREC = 8 + MQTT_MS_QUEUED = 9 class PahoClientMode(enum.IntEnum): diff --git a/src/paho/mqtt/publish.py b/src/paho/mqtt/publish.py index 1fdac2d6..4fce7081 100644 --- a/src/paho/mqtt/publish.py +++ b/src/paho/mqtt/publish.py @@ -21,14 +21,13 @@ from __future__ import annotations import collections -import typing from collections.abc import Iterable -from typing import Any, List, Tuple, Union +from typing import TYPE_CHECKING, Any, List, Tuple, Union from .. import mqtt from . import client as paho -if typing.TYPE_CHECKING: +if TYPE_CHECKING: try: from typing import NotRequired, Required, TypedDict # type: ignore except ImportError: