Skip to content

Commit

Permalink
Persist MPP resolution status in wallet file.
Browse files Browse the repository at this point in the history
If we accept a MPP and we forward the payment (trampoline or swap),
we need to persist the payment accepted status, or we might wrongly
release htlcs on the next restart.

lnworker.received_mpp_htlcs used to be cleaned up in maybe_cleanup_forwarding,
which only applies to forwarded payments. However, since we now
persist this dict, we need to clean it up also in the case of
payments received by us. This part of maybe_cleanup_forwarding has
been migrated to lnworker.maybe_cleanup_mpp
  • Loading branch information
ecdsa committed Jun 7, 2024
1 parent 7a0bffc commit 04de266
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 33 deletions.
6 changes: 5 additions & 1 deletion electrum/lnpeer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2747,6 +2747,7 @@ async def wrapped_callback():
# return payment_key so this branch will not be executed again
return None, payment_key, None
elif preimage:
self.lnworker.maybe_cleanup_mpp(chan.get_scid_or_local_alias(), htlc)
return preimage, None, None
else:
# we are waiting for mpp consolidation or preimage
Expand All @@ -2758,7 +2759,10 @@ async def wrapped_callback():
preimage = self.lnworker.get_preimage(payment_hash)
error_bytes, error_reason = self.lnworker.get_forwarding_failure(payment_key)
if error_bytes or error_reason or preimage:
self.lnworker.maybe_cleanup_forwarding(payment_key, chan.get_scid_or_local_alias(), htlc)
cleanup_keys = self.lnworker.maybe_cleanup_mpp(chan.get_scid_or_local_alias(), htlc)
is_htlc_key = ':' in payment_key
if is_htlc_key or payment_key in cleanup_keys:
self.lnworker.maybe_cleanup_forwarding(payment_key)
if error_bytes:
return None, None, error_bytes
if error_reason:
Expand Down
71 changes: 39 additions & 32 deletions electrum/lnworker.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
from .channel_db import ChannelInfo, Policy
from .mpp_split import suggest_splits, SplitConfigRating
from .trampoline import create_trampoline_route_and_onion, is_legacy_relay
from .json_db import stored_in

if TYPE_CHECKING:
from .network import Network
Expand Down Expand Up @@ -169,18 +170,25 @@ class PaymentInfo(NamedTuple):
status: int


class RecvMPPResolution(Enum):
WAITING = enum.auto()
EXPIRED = enum.auto()
ACCEPTED = enum.auto()
FAILED = enum.auto()
class RecvMPPResolution(IntEnum):
WAITING = 0
EXPIRED = 1
ACCEPTED = 2
FAILED = 3


class ReceivedMPPStatus(NamedTuple):
resolution: RecvMPPResolution
expected_msat: int
htlc_set: Set[Tuple[ShortChannelID, UpdateAddHtlc]]

@stored_in('received_mpp_htlcs', tuple)
def from_tuple(resolution, expected_msat, htlc_list) -> 'ReceivedMPPStatus':
htlc_set = set([(ShortChannelID(bytes.fromhex(scid)), UpdateAddHtlc.from_tuple(*x)) for (scid,x) in htlc_list])
return ReceivedMPPStatus(
resolution=RecvMPPResolution(resolution),
expected_msat=expected_msat,
htlc_set=htlc_set)

SentHtlcKey = Tuple[bytes, ShortChannelID, int] # RHASH, scid, htlc_id

Expand Down Expand Up @@ -851,7 +859,7 @@ def __init__(self, wallet: 'Abstract_Wallet', xprv):

self._paysessions = dict() # type: Dict[bytes, PaySession]
self.sent_htlcs_info = dict() # type: Dict[SentHtlcKey, SentHtlcInfo]
self.received_mpp_htlcs = dict() # type: Dict[bytes, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus
self.received_mpp_htlcs = self.db.get_dict('received_mpp_htlcs') # type: Dict[str, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus

# detect inflight payments
self.inflight_payments = set() # (not persisted) keys of invoices that are in PR_INFLIGHT state
Expand Down Expand Up @@ -2235,7 +2243,7 @@ def bundle_payments(self, hash_list):
payment_keys = [self._get_payment_key(x) for x in hash_list]
self.payment_bundles.append(payment_keys)

def get_payment_bundle(self, payment_key):
def get_payment_bundle(self, payment_key: bytes) -> Sequence[bytes]:
for key_list in self.payment_bundles:
if payment_key in key_list:
return key_list
Expand Down Expand Up @@ -2302,7 +2310,7 @@ def check_mpp_status(
payment_key = payment_hash + payment_secret
self.update_mpp_with_received_htlc(
payment_key=payment_key, scid=short_channel_id, htlc=htlc, expected_msat=expected_msat)
mpp_resolution = self.received_mpp_htlcs[payment_key].resolution
mpp_resolution = self.received_mpp_htlcs[payment_key.hex()].resolution
# if still waiting, calc resolution now:
if mpp_resolution == RecvMPPResolution.WAITING:
bundle = self.get_payment_bundle(payment_key)
Expand All @@ -2323,7 +2331,7 @@ def check_mpp_status(
# save resolution, if any.
if mpp_resolution != RecvMPPResolution.WAITING:
for pkey in payment_keys:
if pkey in self.received_mpp_htlcs:
if pkey.hex() in self.received_mpp_htlcs:
self.set_mpp_resolution(payment_key=pkey, resolution=mpp_resolution)

return mpp_resolution
Expand All @@ -2337,7 +2345,7 @@ def update_mpp_with_received_htlc(
expected_msat: int,
):
# add new htlc to set
mpp_status = self.received_mpp_htlcs.get(payment_key)
mpp_status = self.received_mpp_htlcs.get(payment_key.hex())
if mpp_status is None:
mpp_status = ReceivedMPPStatus(
resolution=RecvMPPResolution.WAITING,
Expand All @@ -2351,47 +2359,46 @@ def update_mpp_with_received_htlc(
key = (scid, htlc)
if key not in mpp_status.htlc_set:
mpp_status.htlc_set.add(key) # side-effecting htlc_set
self.received_mpp_htlcs[payment_key] = mpp_status
self.received_mpp_htlcs[payment_key.hex()] = mpp_status

def set_mpp_resolution(self, *, payment_key: bytes, resolution: RecvMPPResolution):
mpp_status = self.received_mpp_htlcs[payment_key]
self.received_mpp_htlcs[payment_key] = mpp_status._replace(resolution=resolution)
mpp_status = self.received_mpp_htlcs[payment_key.hex()]
self.logger.info(f'set_mpp_resolution {resolution.name} {len(mpp_status.htlc_set)}')
self.received_mpp_htlcs[payment_key.hex()] = mpp_status._replace(resolution=resolution)

def is_mpp_amount_reached(self, payment_key: bytes) -> bool:
mpp_status = self.received_mpp_htlcs.get(payment_key)
mpp_status = self.received_mpp_htlcs.get(payment_key.hex())
if not mpp_status:
return False
total = sum([_htlc.amount_msat for scid, _htlc in mpp_status.htlc_set])
return total >= mpp_status.expected_msat

def get_first_timestamp_of_mpp(self, payment_key: bytes) -> int:
mpp_status = self.received_mpp_htlcs.get(payment_key)
mpp_status = self.received_mpp_htlcs.get(payment_key.hex())
if not mpp_status:
return int(time.time())
return min([_htlc.timestamp for scid, _htlc in mpp_status.htlc_set])

def maybe_cleanup_forwarding(
def maybe_cleanup_mpp(
self,
payment_key_hex: str,
short_channel_id: ShortChannelID,
htlc: UpdateAddHtlc,
) -> None:

is_htlc_key = ':' in payment_key_hex
if not is_htlc_key:
payment_key = bytes.fromhex(payment_key_hex)
mpp_status = self.received_mpp_htlcs.get(payment_key)
if not mpp_status or mpp_status.resolution == RecvMPPResolution.WAITING:
# After restart, self.received_mpp_htlcs needs to be reconstructed
self.logger.info(f'maybe_cleanup_forwarding: mpp_status not ready')
return
htlc_key = (short_channel_id, htlc)
) -> Sequence[str]:
htlc_key = (short_channel_id, htlc)
cleanup_keys = []
for payment_key_hex, mpp_status in list(self.received_mpp_htlcs.items()):
if htlc_key not in mpp_status.htlc_set:
continue
assert mpp_status.resolution != RecvMPPResolution.WAITING
self.logger.info(f'maybe_cleanup_mpp: removing htlc of MPP')
mpp_status.htlc_set.remove(htlc_key) # side-effecting htlc_set
if mpp_status.htlc_set:
return
self.logger.info('cleaning up mpp')
self.received_mpp_htlcs.pop(payment_key)
if len(mpp_status.htlc_set) == 0:
self.logger.info(f'maybe_cleanup_mpp: removing MPP')
self.received_mpp_htlcs.pop(payment_key_hex)
cleanup_keys.append(payment_key_hex)
return cleanup_keys

def maybe_cleanup_forwarding(self, payment_key_hex: str) -> None:
self.active_forwardings.pop(payment_key_hex, None)
self.forwarding_failures.pop(payment_key_hex, None)

Expand Down
8 changes: 8 additions & 0 deletions tests/test_lnpeer.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ async def create_routes_from_invoice(self, amount_msat: int, decoded_invoice: Ln
save_forwarding_failure = LNWallet.save_forwarding_failure
get_forwarding_failure = LNWallet.get_forwarding_failure
maybe_cleanup_forwarding = LNWallet.maybe_cleanup_forwarding
maybe_cleanup_mpp = LNWallet.maybe_cleanup_mpp


class MockTransport:
Expand Down Expand Up @@ -1727,6 +1728,7 @@ async def pay(
):
alice_w = graph.workers['alice']
bob_w = graph.workers['bob']
carol_w = graph.workers['carol']
dave_w = graph.workers['dave']
if mpp_invoice:
dave_w.features |= LnFeatures.BASIC_MPP_OPT
Expand All @@ -1748,6 +1750,12 @@ async def pay(
await asyncio.sleep(2)
if result:
self.assertEqual(PR_PAID, dave_w.get_payment_status(lnaddr.paymenthash))
# check mpp is cleaned up
async with OldTaskGroup() as g:
for peer in peers:
await g.spawn(peer.wait_one_htlc_switch_iteration())
for peer in peers:
self.assertEqual(len(peer.lnworker.received_mpp_htlcs), 0)
raise PaymentDone()
elif len(log) == 1 and log[0].failure_msg.code == OnionFailureCode.MPP_TIMEOUT:
raise PaymentTimeout()
Expand Down

0 comments on commit 04de266

Please sign in to comment.