Skip to content

Commit

Permalink
monkey patching helper
Browse files Browse the repository at this point in the history
  • Loading branch information
altendky committed Oct 31, 2024
1 parent fd442a6 commit 10b0543
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 55 deletions.
19 changes: 19 additions & 0 deletions chia/_tests/util/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
from chia._tests.core.data_layer.util import ChiaRoot
from chia._tests.util.time_out_assert import DataTypeProtocol, caller_file_and_line
from chia.full_node.mempool import Mempool
from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.server.api_protocol import ApiMetadata, ApiProtocol, api_attribute_name
from chia.server.outbound_message import Message
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.condition_opcodes import ConditionOpcode
from chia.util.hash import std_hash
Expand Down Expand Up @@ -630,3 +633,19 @@ def __ge__(self: T_ComparableEnum, other: T_ComparableEnum) -> object:
return NotImplemented

return self.value.__ge__(other.value)


@contextlib.contextmanager
def patch_request_handler(
api: type[ApiProtocol], handler: Callable[..., Awaitable[Optional[Message]]]
) -> Iterator[None]:
message_type = ProtocolMessageTypes[handler.__name__]

api_metadata = ApiMetadata()
wrapped = api_metadata.request()(handler)

with pytest.MonkeyPatch().context() as m:
m.setattr(wrapped, api_attribute_name, api.api)
m.setitem(api.api.message_type_to_request, message_type, api_metadata.message_type_to_request[message_type])

yield
97 changes: 42 additions & 55 deletions chia/_tests/wallet/test_wallet_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
import pytest
from chia_rs import G1Element, PrivateKey

from chia._tests.util.misc import CoinGenerator
from chia._tests.util.misc import CoinGenerator, patch_request_handler
from chia._tests.util.setup_nodes import OldSimulatorsAndWallets
from chia._tests.util.time_out_assert import time_out_assert
from chia.full_node.full_node_api import FullNodeAPI
from chia.protocols import wallet_protocol
from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.protocols.wallet_protocol import CoinState
from chia.server.api_protocol import ApiMetadata, Self
from chia.server.api_protocol import Self
from chia.server.outbound_message import Message, make_msg
from chia.simulator.add_blocks_in_batches import add_blocks_in_batches
from chia.simulator.block_tools import test_constants
Expand Down Expand Up @@ -623,58 +624,52 @@ async def test_transaction_send_cache(
# Replacing the normal logic a full node has for processing transactions with a function that just logs what it gets
logged_spends = []

api = ApiMetadata()

@api.request()
async def send_transaction(
self: Self, request: wallet_protocol.SendTransaction, *, test: bool = False
) -> Optional[Message]:
logged_spends.append(request.transaction.name())
return None

assert full_node_api.full_node._server is not None
monkeypatch.setattr(
full_node_api.full_node._server.get_connections()[0].api,
"send_transaction",
types.MethodType(send_transaction, full_node_api.full_node._server.get_connections()[0].api),
)
with patch_request_handler(api=FullNodeAPI, handler=send_transaction):
# Generate the transaction
async with wallet.wallet_state_manager.new_action_scope(DEFAULT_TX_CONFIG, push=True) as action_scope:
await wallet.generate_signed_transaction(uint64(0), bytes32.zeros, action_scope)
[tx] = action_scope.side_effects.transactions

# Generate the transaction
async with wallet.wallet_state_manager.new_action_scope(DEFAULT_TX_CONFIG, push=True) as action_scope:
await wallet.generate_signed_transaction(uint64(0), bytes32.zeros, action_scope)
[tx] = action_scope.side_effects.transactions
# Make sure it is sent to the peer
await wallet_node._resend_queue()

# Make sure it is sent to the peer
await wallet_node._resend_queue()
def logged_spends_len() -> int:
return len(logged_spends)

def logged_spends_len() -> int:
return len(logged_spends)
await time_out_assert(5, logged_spends_len, 1)

await time_out_assert(5, logged_spends_len, 1)
# Make sure queue processing again does not result in another spend
await wallet_node._resend_queue()
with pytest.raises(AssertionError):
await time_out_assert(5, logged_spends_len, 2)

# Make sure queue processing again does not result in another spend
await wallet_node._resend_queue()
with pytest.raises(AssertionError):
await time_out_assert(5, logged_spends_len, 2)
# Tell the wallet that we recieved the spend (but failed to process it so it should send again)
msg = make_msg(
ProtocolMessageTypes.transaction_ack,
wallet_protocol.TransactionAck(
tx.name, uint8(MempoolInclusionStatus.FAILED), Err.GENERATOR_RUNTIME_ERROR.name
),
)
assert simulator_and_wallet[1][0][0]._server is not None
await simulator_and_wallet[1][0][0]._server.get_connections()[0].incoming_queue.put(msg)

# Tell the wallet that we recieved the spend (but failed to process it so it should send again)
msg = make_msg(
ProtocolMessageTypes.transaction_ack,
wallet_protocol.TransactionAck(tx.name, uint8(MempoolInclusionStatus.FAILED), Err.GENERATOR_RUNTIME_ERROR.name),
)
assert simulator_and_wallet[1][0][0]._server is not None
await simulator_and_wallet[1][0][0]._server.get_connections()[0].incoming_queue.put(msg)
# Make sure the cache is emptied
def check_wallet_cache_empty() -> bool:
return wallet_node._tx_messages_in_progress == {}

# Make sure the cache is emptied
def check_wallet_cache_empty() -> bool:
return wallet_node._tx_messages_in_progress == {}
await time_out_assert(5, check_wallet_cache_empty, True)

await time_out_assert(5, check_wallet_cache_empty, True)
# Re-process the queue again and this time it should result in a resend
await wallet_node._resend_queue()
await time_out_assert(5, logged_spends_len, 2)
assert logged_spends == [tx.name, tx.name]

# Re-process the queue again and this time it should result in a resend
await wallet_node._resend_queue()
await time_out_assert(5, logged_spends_len, 2)
assert logged_spends == [tx.name, tx.name]
await time_out_assert(5, check_wallet_cache_empty, False)

# Disconnect from the peer to make sure their entry in the cache is also deleted
Expand All @@ -691,9 +686,6 @@ async def test_wallet_node_bad_coin_state_ignore(

await wallet_server.start_client(PeerInfo(self_hostname, full_node_api.server.get_port()), None)

api = ApiMetadata()

@api.request()
async def register_interest_in_coin(
self: Self, request: wallet_protocol.RegisterForCoinUpdates, *, test: bool = False
) -> Optional[Message]:
Expand All @@ -708,20 +700,15 @@ async def validate_received_state_from_peer(*args: Any) -> bool:
# It's an interesting case here where we don't hit this unless something is broken
return True # pragma: no cover

assert full_node_api.full_node._server is not None
monkeypatch.setattr(
full_node_api.full_node._server.get_connections()[0].api,
"register_interest_in_coin",
types.MethodType(register_interest_in_coin, full_node_api.full_node._server.get_connections()[0].api),
)
monkeypatch.setattr(
wallet_node,
"validate_received_state_from_peer",
types.MethodType(validate_received_state_from_peer, wallet_node),
)
with patch_request_handler(api=FullNodeAPI, handler=register_interest_in_coin):
monkeypatch.setattr(
wallet_node,
"validate_received_state_from_peer",
types.MethodType(validate_received_state_from_peer, wallet_node),
)

with pytest.raises(PeerRequestException):
await wallet_node.get_coin_state([], wallet_node.get_full_node_peer())
with pytest.raises(PeerRequestException):
await wallet_node.get_coin_state([], wallet_node.get_full_node_peer())


@pytest.mark.anyio
Expand Down

0 comments on commit 10b0543

Please sign in to comment.