Skip to content

Commit

Permalink
add support for "unix" transport where socket module contains AF_UNIX
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrej Krpic committed Mar 25, 2024
1 parent f544b4e commit fc70a41
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 24 deletions.
25 changes: 20 additions & 5 deletions src/paho/mqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,10 @@ class Client:
:param transport: use "websockets" to use WebSockets as the transport
mechanism. Set to "tcp" to use raw TCP, which is the default.
Use "unix" to use Unix sockets as the transport mechanism; note that
this option is only available on platforms that support Unix sockets,
and the "host" argument is interpreted as the path to the Unix socket
file in this case.
:param bool manual_ack: normally, when a message is received, the library automatically
acknowledges after on_message callback returns. manual_ack=True allows the application to
Expand Down Expand Up @@ -733,14 +737,16 @@ def __init__(
clean_session: bool | None = None,
userdata: Any = None,
protocol: MQTTProtocolVersion = MQTTv311,
transport: Literal["tcp", "websockets"] = "tcp",
transport: Literal["tcp", "websockets", "unix"] = "tcp",
reconnect_on_failure: bool = True,
manual_ack: bool = False,
) -> None:
transport = transport.lower() # type: ignore
if transport not in ("websockets", "tcp"):
if transport == "unix" and not hasattr(socket, "AF_UNIX"):
raise ValueError('"unix" transport not supported')
elif transport not in ("websockets", "tcp", "unix"):
raise ValueError(
f'transport must be "websockets" or "tcp", not {transport}')
f'transport must be "websockets", "tcp" or "unix", not {transport}')

self._manual_ack = manual_ack
self._transport = transport
Expand Down Expand Up @@ -931,7 +937,7 @@ def keepalive(self, value: int) -> None:
self._keepalive = value

@property
def transport(self) -> Literal["tcp", "websockets"]:
def transport(self) -> Literal["tcp", "websockets", "unix"]:
"""
Transport method used for the connection ("tcp" or "websockets").
Expand Down Expand Up @@ -4595,7 +4601,11 @@ def _get_proxy(self) -> dict[str, Any] | None:
return None

def _create_socket(self) -> SocketLike:
sock = self._create_socket_connection()
if self._transport == "unix":
sock = self._create_unix_socket_connection()
else:
sock = self._create_socket_connection()

if self._ssl:
sock = self._ssl_wrap_socket(sock)

Expand All @@ -4612,6 +4622,11 @@ def _create_socket(self) -> SocketLike:

return sock

def _create_unix_socket_connection(self) -> _socket.socket:
unix_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
unix_socket.connect(self._host)
return unix_socket

def _create_socket_connection(self) -> _socket.socket:
proxy = self._get_proxy()
addr = (self._host, self._port)
Expand Down
29 changes: 20 additions & 9 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def test_01_con_discon_success(self, proto_ver, callback_version, fake_broker):
callback_version,
"01-con-discon-success",
protocol=proto_ver,
transport=fake_broker.transport,
)

def on_connect(mqttc, obj, flags, rc_or_reason_code, properties_or_none=None):
Expand Down Expand Up @@ -70,7 +71,8 @@ def on_connect(mqttc, obj, flags, rc_or_reason_code, properties_or_none=None):

def test_01_con_failure_rc(self, proto_ver, callback_version, fake_broker):
mqttc = client.Client(
callback_version, "01-con-failure-rc", protocol=proto_ver)
callback_version, "01-con-failure-rc",
protocol=proto_ver, transport=fake_broker.transport)

def on_connect(mqttc, obj, flags, rc_or_reason_code, properties_or_none=None):
assert rc_or_reason_code > 0
Expand Down Expand Up @@ -107,7 +109,9 @@ def on_connect(mqttc, obj, flags, rc_or_reason_code, properties_or_none=None):
mqttc.loop_stop()

def test_connection_properties(self, proto_ver, callback_version, fake_broker):
mqttc = client.Client(CallbackAPIVersion.VERSION2, "client-id", protocol=proto_ver)
mqttc = client.Client(
CallbackAPIVersion.VERSION2, "client-id",
protocol=proto_ver, transport=fake_broker.transport)
mqttc.enable_logger()

is_connected = threading.Event()
Expand All @@ -131,7 +135,7 @@ def on_disconnect(*args):
mqttc.keepalive = 7
mqttc.max_inflight_messages = 7
mqttc.max_queued_messages = 7
mqttc.transport = "tcp"
mqttc.transport = fake_broker.transport
mqttc.username = "username"
mqttc.password = "password"

Expand Down Expand Up @@ -184,7 +188,7 @@ def on_disconnect(*args):
mqttc.max_queued_messages = 7

with pytest.raises(RuntimeError):
mqttc.transport = "tcp"
mqttc.transport = fake_broker.transport

with pytest.raises(RuntimeError):
mqttc.username = "username"
Expand Down Expand Up @@ -217,7 +221,9 @@ class Test_connect_v5:
"""

def test_01_broker_no_support(self, fake_broker):
mqttc = client.Client(CallbackAPIVersion.VERSION2, "01-broker-no-support", protocol=MQTTProtocolVersion.MQTTv5)
mqttc = client.Client(
CallbackAPIVersion.VERSION2, "01-broker-no-support",
protocol=MQTTProtocolVersion.MQTTv5, transport=fake_broker.transport)

def on_connect(mqttc, obj, flags, reason, properties):
assert reason == 132
Expand Down Expand Up @@ -261,6 +267,7 @@ def test_with_loop_start(self, fake_broker: FakeBroker):
"test_with_loop_start",
protocol=MQTTProtocolVersion.MQTTv311,
reconnect_on_failure=False,
transport=fake_broker.transport
)

on_connect_reached = threading.Event()
Expand Down Expand Up @@ -311,6 +318,7 @@ def test_with_loop(self, fake_broker: FakeBroker):
CallbackAPIVersion.VERSION1,
"test_with_loop",
clean_session=True,
transport=fake_broker.transport,
)

on_connect_reached = threading.Event()
Expand Down Expand Up @@ -367,6 +375,7 @@ def test_publish_before_connect(self, fake_broker: FakeBroker) -> None:
mqttc = client.Client(
CallbackAPIVersion.VERSION1,
"test_publish_before_connect",
transport=fake_broker.transport,
)

def on_connect(mqttc, obj, flags, rc):
Expand Down Expand Up @@ -424,7 +433,7 @@ def on_connect(mqttc, obj, flags, rc):
])
class TestPublishBroker2Client:
def test_invalid_utf8_topic(self, callback_version, fake_broker):
mqttc = client.Client(callback_version, "client-id")
mqttc = client.Client(callback_version, "client-id", transport=fake_broker.transport)

def on_message(client, userdata, msg):
with pytest.raises(UnicodeDecodeError):
Expand Down Expand Up @@ -466,7 +475,7 @@ def on_message(client, userdata, msg):
assert not packet_in # Check connection is closed

def test_valid_utf8_topic_recv(self, callback_version, fake_broker):
mqttc = client.Client(callback_version, "client-id")
mqttc = client.Client(callback_version, "client-id", transport=fake_broker.transport)

# It should be non-ascii multi-bytes character
topic = unicodedata.lookup('SNOWMAN')
Expand Down Expand Up @@ -512,7 +521,7 @@ def on_message(client, userdata, msg):
assert not packet_in # Check connection is closed

def test_valid_utf8_topic_publish(self, callback_version, fake_broker):
mqttc = client.Client(callback_version, "client-id")
mqttc = client.Client(callback_version, "client-id", transport=fake_broker.transport)

# It should be non-ascii multi-bytes character
topic = unicodedata.lookup('SNOWMAN')
Expand Down Expand Up @@ -558,7 +567,7 @@ def test_valid_utf8_topic_publish(self, callback_version, fake_broker):
assert not packet_in # Check connection is closed

def test_message_callback(self, callback_version, fake_broker):
mqttc = client.Client(callback_version, "client-id")
mqttc = client.Client(callback_version, "client-id", transport=fake_broker.transport)
userdata = {
'on_message': 0,
'callback1': 0,
Expand Down Expand Up @@ -698,6 +707,7 @@ def test_callback_v1_mqtt3(self, fake_broker):
CallbackAPIVersion.VERSION1,
"client-id",
userdata=callback_called,
transport=fake_broker.transport,
)

def on_connect(cl, userdata, flags, rc):
Expand Down Expand Up @@ -823,6 +833,7 @@ def test_callback_v2_mqtt3(self, fake_broker):
CallbackAPIVersion.VERSION2,
"client-id",
userdata=callback_called,
transport=fake_broker.transport,
)

def on_connect(cl, userdata, flags, reason, properties):
Expand Down
36 changes: 26 additions & 10 deletions tests/testsupport/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,35 @@
import socket
import socketserver
import threading
import os

import pytest

from tests import paho_test


class FakeBroker:
def __init__(self):
# Bind to "localhost" for maximum performance, as described in:
# http://docs.python.org/howto/sockets.html#ipc
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
def __init__(self, transport):
if transport == "tcp":
# Bind to "localhost" for maximum performance, as described in:
# http://docs.python.org/howto/sockets.html#ipc
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(("localhost", 0))
self.port = sock.getsockname()[1]
elif transport == "unix":
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.bind("localhost")
self.port = 1883
else:
raise ValueError(f"unsupported transport {transport}")

sock.settimeout(5)
sock.bind(("localhost", 0))
self.port = sock.getsockname()[1]
sock.listen(1)

self._sock = sock
self._conn = None
self.transport = transport

def start(self):
if self._sock is None:
Expand All @@ -39,6 +49,12 @@ def finish(self):
self._sock.close()
self._sock = None

if self.transport == 'unix':
try:
os.unlink('localhost')
except OSError:
pass

def receive_packet(self, num_bytes):
if self._conn is None:
raise ValueError('Connection is not open')
Expand All @@ -60,10 +76,10 @@ def expect_packet(self, name, packet):
paho_test.expect_packet(self._conn, name, packet)


@pytest.fixture
def fake_broker():
@pytest.fixture(params=["tcp"] + (["unix"] if hasattr(socket, 'AF_UNIX') else []))
def fake_broker(request):
# print('Setup broker')
broker = FakeBroker()
broker = FakeBroker(request.param)

yield broker

Expand Down

0 comments on commit fc70a41

Please sign in to comment.