diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index f1639f75b08..2457cd540a2 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -2747,6 +2747,7 @@ async def wrapped_callback(): # return payment_key so this branch will not be executed again return None, payment_key, None elif preimage: + self.lnworker.maybe_cleanup_mpp(chan.get_scid_or_local_alias(), htlc) return preimage, None, None else: # we are waiting for mpp consolidation or preimage @@ -2758,7 +2759,10 @@ async def wrapped_callback(): preimage = self.lnworker.get_preimage(payment_hash) error_bytes, error_reason = self.lnworker.get_forwarding_failure(payment_key) if error_bytes or error_reason or preimage: - self.lnworker.maybe_cleanup_forwarding(payment_key, chan.get_scid_or_local_alias(), htlc) + cleanup_keys = self.lnworker.maybe_cleanup_mpp(chan.get_scid_or_local_alias(), htlc) + is_htlc_key = ':' in payment_key + if is_htlc_key or payment_key in cleanup_keys: + self.lnworker.maybe_cleanup_forwarding(payment_key) if error_bytes: return None, None, error_bytes if error_reason: diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 2b4e24df33d..67afa3e0916 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -87,6 +87,7 @@ from .channel_db import ChannelInfo, Policy from .mpp_split import suggest_splits, SplitConfigRating from .trampoline import create_trampoline_route_and_onion, is_legacy_relay +from .json_db import stored_in if TYPE_CHECKING: from .network import Network @@ -169,11 +170,11 @@ class PaymentInfo(NamedTuple): status: int -class RecvMPPResolution(Enum): - WAITING = enum.auto() - EXPIRED = enum.auto() - ACCEPTED = enum.auto() - FAILED = enum.auto() +class RecvMPPResolution(IntEnum): + WAITING = 0 + EXPIRED = 1 + ACCEPTED = 2 + FAILED = 3 class ReceivedMPPStatus(NamedTuple): @@ -181,6 +182,13 @@ class ReceivedMPPStatus(NamedTuple): expected_msat: int htlc_set: Set[Tuple[ShortChannelID, UpdateAddHtlc]] + @stored_in('received_mpp_htlcs', tuple) + def from_tuple(resolution, expected_msat, htlc_list) -> 'ReceivedMPPStatus': + htlc_set = set([(ShortChannelID(bytes.fromhex(scid)), UpdateAddHtlc.from_tuple(*x)) for (scid,x) in htlc_list]) + return ReceivedMPPStatus( + resolution=RecvMPPResolution(resolution), + expected_msat=expected_msat, + htlc_set=htlc_set) SentHtlcKey = Tuple[bytes, ShortChannelID, int] # RHASH, scid, htlc_id @@ -851,7 +859,7 @@ def __init__(self, wallet: 'Abstract_Wallet', xprv): self._paysessions = dict() # type: Dict[bytes, PaySession] self.sent_htlcs_info = dict() # type: Dict[SentHtlcKey, SentHtlcInfo] - self.received_mpp_htlcs = dict() # type: Dict[bytes, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus + self.received_mpp_htlcs = self.db.get_dict('received_mpp_htlcs') # type: Dict[str, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus # detect inflight payments self.inflight_payments = set() # (not persisted) keys of invoices that are in PR_INFLIGHT state @@ -2235,7 +2243,7 @@ def bundle_payments(self, hash_list): payment_keys = [self._get_payment_key(x) for x in hash_list] self.payment_bundles.append(payment_keys) - def get_payment_bundle(self, payment_key): + def get_payment_bundle(self, payment_key: bytes) -> Sequence[bytes]: for key_list in self.payment_bundles: if payment_key in key_list: return key_list @@ -2302,7 +2310,7 @@ def check_mpp_status( payment_key = payment_hash + payment_secret self.update_mpp_with_received_htlc( payment_key=payment_key, scid=short_channel_id, htlc=htlc, expected_msat=expected_msat) - mpp_resolution = self.received_mpp_htlcs[payment_key].resolution + mpp_resolution = self.received_mpp_htlcs[payment_key.hex()].resolution # if still waiting, calc resolution now: if mpp_resolution == RecvMPPResolution.WAITING: bundle = self.get_payment_bundle(payment_key) @@ -2323,7 +2331,7 @@ def check_mpp_status( # save resolution, if any. if mpp_resolution != RecvMPPResolution.WAITING: for pkey in payment_keys: - if pkey in self.received_mpp_htlcs: + if pkey.hex() in self.received_mpp_htlcs: self.set_mpp_resolution(payment_key=pkey, resolution=mpp_resolution) return mpp_resolution @@ -2337,7 +2345,7 @@ def update_mpp_with_received_htlc( expected_msat: int, ): # add new htlc to set - mpp_status = self.received_mpp_htlcs.get(payment_key) + mpp_status = self.received_mpp_htlcs.get(payment_key.hex()) if mpp_status is None: mpp_status = ReceivedMPPStatus( resolution=RecvMPPResolution.WAITING, @@ -2351,47 +2359,46 @@ def update_mpp_with_received_htlc( key = (scid, htlc) if key not in mpp_status.htlc_set: mpp_status.htlc_set.add(key) # side-effecting htlc_set - self.received_mpp_htlcs[payment_key] = mpp_status + self.received_mpp_htlcs[payment_key.hex()] = mpp_status def set_mpp_resolution(self, *, payment_key: bytes, resolution: RecvMPPResolution): - mpp_status = self.received_mpp_htlcs[payment_key] - self.received_mpp_htlcs[payment_key] = mpp_status._replace(resolution=resolution) + mpp_status = self.received_mpp_htlcs[payment_key.hex()] + self.logger.info(f'set_mpp_resolution {resolution.name} {len(mpp_status.htlc_set)}') + self.received_mpp_htlcs[payment_key.hex()] = mpp_status._replace(resolution=resolution) def is_mpp_amount_reached(self, payment_key: bytes) -> bool: - mpp_status = self.received_mpp_htlcs.get(payment_key) + mpp_status = self.received_mpp_htlcs.get(payment_key.hex()) if not mpp_status: return False total = sum([_htlc.amount_msat for scid, _htlc in mpp_status.htlc_set]) return total >= mpp_status.expected_msat def get_first_timestamp_of_mpp(self, payment_key: bytes) -> int: - mpp_status = self.received_mpp_htlcs.get(payment_key) + mpp_status = self.received_mpp_htlcs.get(payment_key.hex()) if not mpp_status: return int(time.time()) return min([_htlc.timestamp for scid, _htlc in mpp_status.htlc_set]) - def maybe_cleanup_forwarding( + def maybe_cleanup_mpp( self, - payment_key_hex: str, short_channel_id: ShortChannelID, htlc: UpdateAddHtlc, - ) -> None: - - is_htlc_key = ':' in payment_key_hex - if not is_htlc_key: - payment_key = bytes.fromhex(payment_key_hex) - mpp_status = self.received_mpp_htlcs.get(payment_key) - if not mpp_status or mpp_status.resolution == RecvMPPResolution.WAITING: - # After restart, self.received_mpp_htlcs needs to be reconstructed - self.logger.info(f'maybe_cleanup_forwarding: mpp_status not ready') - return - htlc_key = (short_channel_id, htlc) + ) -> Sequence[str]: + htlc_key = (short_channel_id, htlc) + cleanup_keys = [] + for payment_key_hex, mpp_status in list(self.received_mpp_htlcs.items()): + if htlc_key not in mpp_status.htlc_set: + continue + assert mpp_status.resolution != RecvMPPResolution.WAITING + self.logger.info(f'maybe_cleanup_mpp: removing htlc of MPP') mpp_status.htlc_set.remove(htlc_key) # side-effecting htlc_set - if mpp_status.htlc_set: - return - self.logger.info('cleaning up mpp') - self.received_mpp_htlcs.pop(payment_key) + if len(mpp_status.htlc_set) == 0: + self.logger.info(f'maybe_cleanup_mpp: removing MPP') + self.received_mpp_htlcs.pop(payment_key_hex) + cleanup_keys.append(payment_key_hex) + return cleanup_keys + def maybe_cleanup_forwarding(self, payment_key_hex: str) -> None: self.active_forwardings.pop(payment_key_hex, None) self.forwarding_failures.pop(payment_key_hex, None) diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index fa0b8c20638..3342aea404d 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -312,6 +312,7 @@ async def create_routes_from_invoice(self, amount_msat: int, decoded_invoice: Ln save_forwarding_failure = LNWallet.save_forwarding_failure get_forwarding_failure = LNWallet.get_forwarding_failure maybe_cleanup_forwarding = LNWallet.maybe_cleanup_forwarding + maybe_cleanup_mpp = LNWallet.maybe_cleanup_mpp class MockTransport: @@ -1727,6 +1728,7 @@ async def pay( ): alice_w = graph.workers['alice'] bob_w = graph.workers['bob'] + carol_w = graph.workers['carol'] dave_w = graph.workers['dave'] if mpp_invoice: dave_w.features |= LnFeatures.BASIC_MPP_OPT @@ -1748,6 +1750,12 @@ async def pay( await asyncio.sleep(2) if result: self.assertEqual(PR_PAID, dave_w.get_payment_status(lnaddr.paymenthash)) + # check mpp is cleaned up + async with OldTaskGroup() as g: + for peer in peers: + await g.spawn(peer.wait_one_htlc_switch_iteration()) + for peer in peers: + self.assertEqual(len(peer.lnworker.received_mpp_htlcs), 0) raise PaymentDone() elif len(log) == 1 and log[0].failure_msg.code == OnionFailureCode.MPP_TIMEOUT: raise PaymentTimeout()