Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Aarni Koskela <[email protected]>
  • Loading branch information
PierreF and akx committed Jan 9, 2024
1 parent 9e51f1c commit 50b280c
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 45 deletions.
58 changes: 30 additions & 28 deletions src/paho/mqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,10 @@
import struct
import threading
import time
import typing
import urllib.parse
import urllib.request
import uuid
from typing import Any, Callable, Dict, Iterator, List, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Sequence, Tuple, Union

from .enums import ConnackCode, ConnectionState, LogLevel, MessageState, MessageType, MQTTErrorCode, MQTTProtocolVersion, PahoClientMode
from .matcher import MQTTMatcher
Expand All @@ -48,7 +47,7 @@
except ImportError:
from typing_extensions import Literal # type: ignore

if typing.TYPE_CHECKING:
if TYPE_CHECKING:
try:
from typing import TypedDict # type: ignore
except ImportError:
Expand Down Expand Up @@ -147,22 +146,22 @@ class _OutPacket(TypedDict):
CONNACK_REFUSED_NOT_AUTHORIZED = ConnackCode.CONNACK_REFUSED_NOT_AUTHORIZED

# Connection state
mqtt_cs_new = ConnectionState.mqtt_cs_new
mqtt_cs_connected = ConnectionState.mqtt_cs_connected
mqtt_cs_disconnecting = ConnectionState.mqtt_cs_disconnecting
mqtt_cs_connect_async = ConnectionState.mqtt_cs_connect_async
mqtt_cs_new = ConnectionState.MQTT_CS_NEW
mqtt_cs_connected = ConnectionState.MQTT_CS_CONNECTED
mqtt_cs_disconnecting = ConnectionState.MQTT_CS_DISCONNECTING
mqtt_cs_connect_async = ConnectionState.MQTT_CS_CONNECT_ASYNC

# Message state
mqtt_ms_invalid = MessageState.mqtt_ms_invalid
mqtt_ms_publish = MessageState.mqtt_ms_publish
mqtt_ms_wait_for_puback = MessageState.mqtt_ms_wait_for_puback
mqtt_ms_wait_for_pubrec = MessageState.mqtt_ms_wait_for_pubrec
mqtt_ms_resend_pubrel = MessageState.mqtt_ms_resend_pubrel
mqtt_ms_wait_for_pubrel = MessageState.mqtt_ms_wait_for_pubrel
mqtt_ms_resend_pubcomp = MessageState.mqtt_ms_resend_pubcomp
mqtt_ms_wait_for_pubcomp = MessageState.mqtt_ms_wait_for_pubcomp
mqtt_ms_send_pubrec = MessageState.mqtt_ms_send_pubrec
mqtt_ms_queued = MessageState.mqtt_ms_queued
mqtt_ms_invalid = MessageState.MQTT_MS_INVALID
mqtt_ms_publish = MessageState.MQTT_MS_PUBLISH
mqtt_ms_wait_for_puback = MessageState.MQTT_MS_WAIT_FOR_PUBACK
mqtt_ms_wait_for_pubrec = MessageState.MQTT_MS_WAIT_FOR_PUBREC
mqtt_ms_resend_pubrel = MessageState.MQTT_MS_RESEND_PUBREL
mqtt_ms_wait_for_pubrel = MessageState.MQTT_MS_WAIT_FOR_PUBREL
mqtt_ms_resend_pubcomp = MessageState.MQTT_MS_RESEND_PUBCOMP
mqtt_ms_wait_for_pubcomp = MessageState.MQTT_MS_WAIT_FOR_PUBCOMP
mqtt_ms_send_pubrec = MessageState.MQTT_MS_SEND_PUBREC
mqtt_ms_queued = MessageState.MQTT_MS_QUEUED

MQTT_ERR_AGAIN = MQTTErrorCode.MQTT_ERR_AGAIN
MQTT_ERR_SUCCESS = MQTTErrorCode.MQTT_ERR_SUCCESS
Expand Down Expand Up @@ -636,7 +635,7 @@ def __init__(
self._transport = transport.lower()
self._protocol = protocol
self._userdata = userdata
self._sock: socket.socket | WebsocketWrapper | ssl.SSLSocket | None = None
self._sock: SocketLike | None = None
self._sockpairR: socket.socket | None = None
self._sockpairW: socket.socket | None = None
self._keepalive = 60
Expand Down Expand Up @@ -3969,13 +3968,20 @@ def _get_proxy(self) -> dict[str, Any] | None:
return None

def _create_socket(self) -> SocketLike:
tcp_sock = self._create_socket_connection()
sock = self._create_ssl_socket_if_enable(tcp_sock)
sock = self._create_socket_connection()
if self._ssl:
sock = self._ssl_wrap_socket(sock)

if self._transport == "websockets":
sock.settimeout(self._keepalive)
return WebsocketWrapper(sock, self._host, self._port, self._ssl,
self._websocket_path, self._websocket_extra_headers)
return WebsocketWrapper(
socket=sock,
host=self._host,
port=self._port,
is_ssl=self._ssl,
path=self._websocket_path,
extra_headers=self._websocket_extra_headers,
)

return sock

Expand All @@ -3989,10 +3995,7 @@ def _create_socket_connection(self) -> _socket.socket:
else:
return socket.create_connection(addr, timeout=self._connect_timeout, source_address=source)

def _create_ssl_socket_if_enable(self, tcp_sock: _socket.socket) -> _socket.socket|ssl.SSLSocket:
if not self._ssl:
return tcp_sock

def _ssl_wrap_socket(self, tcp_sock: _socket.socket) -> ssl.SSLSocket:
if self._ssl_context is None:
raise ValueError(
"Impossible condition. _ssl_context should never be None if _ssl is True"
Expand All @@ -4017,8 +4020,7 @@ def _create_ssl_socket_if_enable(self, tcp_sock: _socket.socket) -> _socket.sock
)
else:
# If SSL context has already checked hostname, then don't need to do it again
if (hasattr(self._ssl_context, 'check_hostname') and
self._ssl_context.check_hostname): # type: ignore
if getattr(self._ssl_context, 'check_hostname', False): # type: ignore
verify_host = False

ssl_sock.settimeout(self._keepalive)
Expand Down
28 changes: 14 additions & 14 deletions src/paho/mqtt/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,23 +64,23 @@ class ConnackCode(enum.IntEnum):


class ConnectionState(enum.IntEnum):
mqtt_cs_new = 0
mqtt_cs_connected = 1
mqtt_cs_disconnecting = 2
mqtt_cs_connect_async = 3
MQTT_CS_NEW = 0
MQTT_CS_CONNECTED = 1
MQTT_CS_DISCONNECTING = 2
MQTT_CS_CONNECT_ASYNC = 3


class MessageState(enum.IntEnum):
mqtt_ms_invalid = 0
mqtt_ms_publish = 1
mqtt_ms_wait_for_puback = 2
mqtt_ms_wait_for_pubrec = 3
mqtt_ms_resend_pubrel = 4
mqtt_ms_wait_for_pubrel = 5
mqtt_ms_resend_pubcomp = 6
mqtt_ms_wait_for_pubcomp = 7
mqtt_ms_send_pubrec = 8
mqtt_ms_queued = 9
MQTT_MS_INVALID = 0
MQTT_MS_PUBLISH = 1
MQTT_MS_WAIT_FOR_PUBACK = 2
MQTT_MS_WAIT_FOR_PUBREC = 3
MQTT_MS_RESEND_PUBREL = 4
MQTT_MS_WAIT_FOR_PUBREL = 5
MQTT_MS_RESEND_PUBCOMP = 6
MQTT_MS_WAIT_FOR_PUBCOMP = 7
MQTT_MS_SEND_PUBREC = 8
MQTT_MS_QUEUED = 9


class PahoClientMode(enum.IntEnum):
Expand Down
5 changes: 2 additions & 3 deletions src/paho/mqtt/publish.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,13 @@
from __future__ import annotations

import collections
import typing
from collections.abc import Iterable
from typing import Any, List, Tuple, Union
from typing import TYPE_CHECKING, Any, List, Tuple, Union

from .. import mqtt
from . import client as paho

if typing.TYPE_CHECKING:
if TYPE_CHECKING:
try:
from typing import NotRequired, Required, TypedDict # type: ignore
except ImportError:
Expand Down

0 comments on commit 50b280c

Please sign in to comment.