diff --git a/src/solana/rpc/websocket_api.py b/src/solana/rpc/websocket_api.py index d97c5edc..f6b3fd2c 100644 --- a/src/solana/rpc/websocket_api.py +++ b/src/solana/rpc/websocket_api.py @@ -416,4 +416,11 @@ def __init__(self, uri: str = "ws://localhost:8900", **kwargs: Any) -> None: uri: The websocket endpoint. **kwargs: Keyword arguments for ``websockets.legacy.client.connect`` """ - super().__init__(uri, **kwargs, create_protocol=SolanaWsClientProtocol) + # Ensure that create_protocol explicitly creates a SolanaWsClientProtocol + kwargs.setdefault("create_protocol", SolanaWsClientProtocol) + super().__init__(uri, **kwargs) + + async def __aenter__(self) -> SolanaWsClientProtocol: + """Overrides to specify the type of protocol explicitly.""" + protocol = await super().__aenter__() + return cast(SolanaWsClientProtocol, protocol)