diff --git a/src/aleph/sdk/client.py b/src/aleph/sdk/client.py index bdd78c23..9877d716 100644 --- a/src/aleph/sdk/client.py +++ b/src/aleph/sdk/client.py @@ -54,7 +54,7 @@ MultipleMessagesError, ) from .models import MessagesResponse -from .utils import get_message_type_value +from .utils import check_unix_socket_valid, get_message_type_value logger = logging.getLogger(__name__) @@ -94,14 +94,14 @@ def func_caller(*args, **kwargs): async def run_async_watcher( - *args, output_queue: queue.Queue, api_server: str, **kwargs + *args, output_queue: queue.Queue, api_server: Optional[str], **kwargs ): async with AlephClient(api_server=api_server) as session: async for message in session.watch_messages(*args, **kwargs): output_queue.put(message) -def watcher_thread(output_queue: queue.Queue, api_server: str, args, kwargs): +def watcher_thread(output_queue: queue.Queue, api_server: Optional[str], args, kwargs): asyncio.run( run_async_watcher( output_queue=output_queue, api_server=api_server, *args, **kwargs @@ -443,9 +443,39 @@ class AlephClient: api_server: str http_session: aiohttp.ClientSession - def __init__(self, api_server: str): - self.api_server = api_server - self.http_session = aiohttp.ClientSession(base_url=api_server) + def __init__( + self, + api_server: Optional[str], + api_unix_socket: Optional[str] = None, + allow_unix_sockets: bool = True, + timeout: Optional[aiohttp.ClientTimeout] = None, + ): + """AlephClient can use HTTP(S) or HTTP over Unix sockets. + Unix sockets are used when running inside a virtual machine, + and can be shared across containers in a more secure way than TCP ports. + """ + self.api_server = api_server or settings.API_HOST + if not self.api_server: + raise ValueError("Missing API host") + + unix_socket_path = api_unix_socket or settings.API_UNIX_SOCKET + if unix_socket_path and allow_unix_sockets: + check_unix_socket_valid(unix_socket_path) + connector = aiohttp.UnixConnector(path=unix_socket_path) + else: + connector = None + + # ClientSession timeout defaults to a private sentinel object and may not be None. + self.http_session = ( + aiohttp.ClientSession( + base_url=self.api_server, connector=connector, timeout=timeout + ) + if timeout + else aiohttp.ClientSession( + base_url=self.api_server, + connector=connector, + ) + ) def __enter__(self) -> UserSessionSync: return UserSessionSync(async_session=self) @@ -825,8 +855,20 @@ class AuthenticatedAlephClient(AlephClient): "channel", } - def __init__(self, account: Account, api_server: str): - super().__init__(api_server=api_server) + def __init__( + self, + account: Account, + api_server: Optional[str], + api_unix_socket: Optional[str] = None, + allow_unix_sockets: bool = True, + timeout: Optional[aiohttp.ClientTimeout] = None, + ): + super().__init__( + api_server=api_server, + api_unix_socket=api_unix_socket, + allow_unix_sockets=allow_unix_sockets, + timeout=timeout, + ) self.account = account def __enter__(self) -> "AuthenticatedUserSessionSync": diff --git a/src/aleph/sdk/utils.py b/src/aleph/sdk/utils.py index 4290bcf6..fdbf6095 100644 --- a/src/aleph/sdk/utils.py +++ b/src/aleph/sdk/utils.py @@ -1,3 +1,4 @@ +import errno import logging import os from pathlib import Path @@ -59,3 +60,19 @@ def get_message_type_value(message_type: Type[GenericMessage]) -> MessageType: """Returns the value of the 'type' field of a message type class.""" type_literal = message_type.__annotations__["type"] return type_literal.__args__[0] # Get the value from a Literal + + +def check_unix_socket_valid(unix_socket_path: str) -> bool: + """Check that a unix socket exists at the given path, or raise a FileNotFoundError.""" + path = Path(unix_socket_path) + if not path.exists(): + raise FileNotFoundError( + errno.ENOENT, os.strerror(errno.ENOENT), unix_socket_path + ) + if not path.is_socket(): + raise FileNotFoundError( + errno.ENOTSOCK, + os.strerror(errno.ENOENT), + unix_socket_path, + ) + return True