diff --git a/src/solana/rpc/websocket_api.py b/src/solana/rpc/websocket_api.py index d97c5edc..94f62812 100644 --- a/src/solana/rpc/websocket_api.py +++ b/src/solana/rpc/websocket_api.py @@ -406,7 +406,7 @@ def _process_rpc_response(self, raw: str) -> List[Union[Notification, Subscripti return cast(List[Union[Notification, SubscriptionResult]], parsed) -class connect(ws_connect): # pylint: disable=invalid-name,too-few-public-methods +class connect(ws_connect): # pylint: disable=invalid-name """Solana RPC websocket connector.""" def __init__(self, uri: str = "ws://localhost:8900", **kwargs: Any) -> None: @@ -416,4 +416,11 @@ def __init__(self, uri: str = "ws://localhost:8900", **kwargs: Any) -> None: uri: The websocket endpoint. **kwargs: Keyword arguments for ``websockets.legacy.client.connect`` """ - super().__init__(uri, **kwargs, create_protocol=SolanaWsClientProtocol) + # Ensure that create_protocol explicitly creates a SolanaWsClientProtocol + kwargs.setdefault("create_protocol", SolanaWsClientProtocol) + super().__init__(uri, **kwargs) + + async def __aenter__(self) -> SolanaWsClientProtocol: + """Overrides to specify the type of protocol explicitly.""" + protocol = await super().__aenter__() + return cast(SolanaWsClientProtocol, protocol)