Skip to content

Commit

Permalink
Fix: AlephClient class could not use unix sockets
Browse files Browse the repository at this point in the history
  • Loading branch information
hoh committed May 16, 2023
1 parent cd49ef8 commit 6148040
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 8 deletions.
58 changes: 50 additions & 8 deletions src/aleph/sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
MultipleMessagesError,
)
from .models import MessagesResponse
from .utils import get_message_type_value
from .utils import check_unix_socket_valid, get_message_type_value

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -94,14 +94,14 @@ def func_caller(*args, **kwargs):


async def run_async_watcher(
*args, output_queue: queue.Queue, api_server: str, **kwargs
*args, output_queue: queue.Queue, api_server: Optional[str], **kwargs
):
async with AlephClient(api_server=api_server) as session:
async for message in session.watch_messages(*args, **kwargs):
output_queue.put(message)


def watcher_thread(output_queue: queue.Queue, api_server: str, args, kwargs):
def watcher_thread(output_queue: queue.Queue, api_server: Optional[str], args, kwargs):
asyncio.run(
run_async_watcher(
output_queue=output_queue, api_server=api_server, *args, **kwargs
Expand Down Expand Up @@ -443,9 +443,39 @@ class AlephClient:
api_server: str
http_session: aiohttp.ClientSession

def __init__(self, api_server: str):
self.api_server = api_server
self.http_session = aiohttp.ClientSession(base_url=api_server)
def __init__(
self,
api_server: Optional[str],
api_unix_socket: Optional[str] = None,
allow_unix_sockets: bool = True,
timeout: Optional[aiohttp.ClientTimeout] = None,
):
"""AlephClient can use HTTP(S) or HTTP over Unix sockets.
Unix sockets are used when running inside a virtual machine,
and can be shared across containers in a more secure way than TCP ports.
"""
self.api_server = api_server or settings.API_HOST
if not self.api_server:
raise ValueError("Missing API host")

unix_socket_path = api_unix_socket or settings.API_UNIX_SOCKET
if unix_socket_path and allow_unix_sockets:
check_unix_socket_valid(unix_socket_path)
connector = aiohttp.UnixConnector(path=unix_socket_path)
else:
connector = None

# ClientSession timeout defaults to a private sentinel object and may not be None.
self.http_session = (
aiohttp.ClientSession(
base_url=self.api_server, connector=connector, timeout=timeout
)
if timeout
else aiohttp.ClientSession(
base_url=self.api_server,
connector=connector,
)
)

def __enter__(self) -> UserSessionSync:
return UserSessionSync(async_session=self)
Expand Down Expand Up @@ -825,8 +855,20 @@ class AuthenticatedAlephClient(AlephClient):
"channel",
}

def __init__(self, account: Account, api_server: str):
super().__init__(api_server=api_server)
def __init__(
self,
account: Account,
api_server: Optional[str],
api_unix_socket: Optional[str] = None,
allow_unix_sockets: bool = True,
timeout: Optional[aiohttp.ClientTimeout] = None,
):
super().__init__(
api_server=api_server,
api_unix_socket=api_unix_socket,
allow_unix_sockets=allow_unix_sockets,
timeout=timeout,
)
self.account = account

def __enter__(self) -> "AuthenticatedUserSessionSync":
Expand Down
17 changes: 17 additions & 0 deletions src/aleph/sdk/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import errno
import logging
import os
from pathlib import Path
Expand Down Expand Up @@ -59,3 +60,19 @@ def get_message_type_value(message_type: Type[GenericMessage]) -> MessageType:
"""Returns the value of the 'type' field of a message type class."""
type_literal = message_type.__annotations__["type"]
return type_literal.__args__[0] # Get the value from a Literal


def check_unix_socket_valid(unix_socket_path: str) -> bool:
"""Check that a unix socket exists at the given path, or raise a FileNotFoundError."""
path = Path(unix_socket_path)
if not path.exists():
raise FileNotFoundError(
errno.ENOENT, os.strerror(errno.ENOENT), unix_socket_path
)
if not path.is_socket():
raise FileNotFoundError(
errno.ENOTSOCK,
os.strerror(errno.ENOENT),
unix_socket_path,
)
return True

0 comments on commit 6148040

Please sign in to comment.