diff --git a/rfq/stream.go b/rfq/stream.go index 3a23332aa..2f78c60a2 100644 --- a/rfq/stream.go +++ b/rfq/stream.go @@ -92,8 +92,12 @@ func NewStreamHandler(ctx context.Context, func (h *StreamHandler) handleIncomingWireMessage( wireMsg rfqmsg.WireMessage) error { - // Parse the wire message as an RFQ message. - msg, err := rfqmsg.NewIncomingMsgFromWire(wireMsg) + // Parse the wire message as an RFQ message. The session cache load + // function is provided to associate incoming wire messages with their + // corresponding outgoing requests during parsing. + msg, err := rfqmsg.NewIncomingMsgFromWire( + wireMsg, h.outgoingRequests.Load, + ) if err != nil { if errors.Is(err, rfqmsg.ErrUnknownMessageType) { // Silently disregard the message if we don't recognise @@ -109,66 +113,13 @@ func (h *StreamHandler) handleIncomingWireMessage( log.Debugf("Stream handling incoming message: %s", msg) - // If the incoming message is an accept message, lookup the - // corresponding outgoing request message. Assign the outgoing request - // to a field on the accept message. This step allows us to easily - // access the request that the accept message is responding to. Some of - // the request fields are not present in the accept message. - // - // If the incoming message is a reject message, remove the corresponding - // outgoing request from the store. - switch typedMsg := msg.(type) { - case *rfqmsg.Reject: - // Delete the corresponding outgoing request from the store. - h.outgoingRequests.Delete(typedMsg.ID) - - case *rfqmsg.BuyAccept: - // Load and delete the corresponding outgoing request from the - // store. - outgoingRequest, found := h.outgoingRequests.LoadAndDelete( - typedMsg.ID, - ) - - // Ensure that we have an outgoing request to match the incoming - // accept message. - if !found { - return fmt.Errorf("no outgoing request found for "+ - "incoming accept message: %s", typedMsg.ID) - } - - // Type cast the outgoing message to a BuyRequest (the request - // type that corresponds to a buy accept message). - buyReq, ok := outgoingRequest.(*rfqmsg.BuyRequest) - if !ok { - return fmt.Errorf("expected BuyRequest, got %T", - outgoingRequest) - } - - typedMsg.Request = *buyReq - - case *rfqmsg.SellAccept: - // Load and delete the corresponding outgoing request from the - // store. - outgoingRequest, found := h.outgoingRequests.LoadAndDelete( - typedMsg.ID, - ) - - // Ensure that we have an outgoing request to match the incoming - // accept message. - if !found { - return fmt.Errorf("no outgoing request found for "+ - "incoming accept message: %s", typedMsg.ID) - } - - // Type cast the outgoing message to a SellRequest (the request - // type that corresponds to a sell accept message). - req, ok := outgoingRequest.(*rfqmsg.SellRequest) - if !ok { - return fmt.Errorf("expected SellRequest, got %T", - outgoingRequest) - } - - typedMsg.Request = *req + // If the incoming message is a response to an outgoing request, we + // will remove the corresponding session from the store. We can safely + // remove the session at this point because we have received the only + // response we expect for this session. + switch msg.(type) { + case *rfqmsg.BuyAccept, *rfqmsg.SellAccept, *rfqmsg.Reject: + h.outgoingRequests.Delete(msg.MsgID()) } // Send the incoming message to the RFQ manager. diff --git a/rfqmsg/accept.go b/rfqmsg/accept.go index 9fd68ec0e..b880eb0b7 100644 --- a/rfqmsg/accept.go +++ b/rfqmsg/accept.go @@ -230,7 +230,9 @@ func (m *acceptWireMsgData) Bytes() ([]byte, error) { // asset to us. Conversely, an incoming sell accept message indicates that our // peer accepts our sell request, meaning they are willing to buy the asset from // us. -func NewIncomingAcceptFromWire(wireMsg WireMessage) (IncomingMsg, error) { +func NewIncomingAcceptFromWire(wireMsg WireMessage, + sessionLookup SessionLookup) (IncomingMsg, error) { + // Ensure that the message type is a quote accept message. if wireMsg.MsgType != MsgTypeAccept { return nil, ErrUnknownMessageType @@ -248,17 +250,30 @@ func NewIncomingAcceptFromWire(wireMsg WireMessage) (IncomingMsg, error) { "quote accept message: %w", err) } - // We will now determine whether this is a buy or sell accept. We can - // distinguish between buy/sell accept messages by inspecting which tick - // rate field is populated. - isBuyAccept := msgData.InOutRateTick.IsSome() + // Before we can determine whether this is a buy or sell accept, we need + // to look up the corresponding outgoing request message. This step is + // necessary because the accept message data does not contain sufficient + // data to distinguish between buy and sell accept messages. + if sessionLookup == nil { + return nil, fmt.Errorf("RFQ session lookup function is " + + "required") + } - // If this is a buy request, then we will create a new buy request - // message. - if isBuyAccept { - return newBuyAcceptFromWireMsg(wireMsg, msgData) + request, found := sessionLookup(msgData.ID.Val) + if !found { + return nil, fmt.Errorf("no outgoing request found for "+ + "incoming accept message: %s", msgData.ID.Val) } - // Otherwise, this is a sell request. - return newSellAcceptFromWireMsg(wireMsg, msgData) + // Use the corresponding request to determine the type of accept + // message. + switch typedRequest := request.(type) { + case *BuyRequest: + return newBuyAcceptFromWireMsg(wireMsg, msgData, *typedRequest) + case *SellRequest: + return newSellAcceptFromWireMsg(wireMsg, msgData, *typedRequest) + default: + return nil, fmt.Errorf("unknown request type for incoming "+ + "accept message: %T", request) + } } diff --git a/rfqmsg/buy_accept.go b/rfqmsg/buy_accept.go index edfdf7bbf..19e79f176 100644 --- a/rfqmsg/buy_accept.go +++ b/rfqmsg/buy_accept.go @@ -58,7 +58,7 @@ func NewBuyAcceptFromRequest(request BuyRequest, askPrice lnwire.MilliSatoshi, // newBuyAcceptFromWireMsg instantiates a new instance from a wire message. func newBuyAcceptFromWireMsg(wireMsg WireMessage, - msgData acceptWireMsgData) (*BuyAccept, error) { + msgData acceptWireMsgData, request BuyRequest) (*BuyAccept, error) { // Ensure that the message type is an accept message. if wireMsg.MsgType != MsgTypeAccept { @@ -79,6 +79,7 @@ func newBuyAcceptFromWireMsg(wireMsg WireMessage, return &BuyAccept{ Peer: wireMsg.Peer, + Request: request, Version: msgData.Version.Val, ID: msgData.ID.Val, Expiry: msgData.Expiry.Val, diff --git a/rfqmsg/buy_request.go b/rfqmsg/buy_request.go index 9a6373699..30b34d633 100644 --- a/rfqmsg/buy_request.go +++ b/rfqmsg/buy_request.go @@ -168,6 +168,16 @@ func (q *BuyRequest) ToWire() (WireMessage, error) { }, nil } +// MsgPeer returns the peer that sent the message. +func (q *BuyRequest) MsgPeer() route.Vertex { + return q.Peer +} + +// MsgID returns the quote request session ID. +func (q *BuyRequest) MsgID() ID { + return q.ID +} + // String returns a human-readable string representation of the message. func (q *BuyRequest) String() string { var groupKeyBytes []byte diff --git a/rfqmsg/messages.go b/rfqmsg/messages.go index 0a3371ba3..cc61b1ffa 100644 --- a/rfqmsg/messages.go +++ b/rfqmsg/messages.go @@ -90,13 +90,19 @@ type WireMessage struct { Data []byte } +// SessionLookup is a function that can be used to look up a session quote +// request message given a session ID. +type SessionLookup func(id ID) (OutgoingMsg, bool) + // NewIncomingMsgFromWire creates a new RFQ message from a wire message. -func NewIncomingMsgFromWire(wireMsg WireMessage) (IncomingMsg, error) { +func NewIncomingMsgFromWire(wireMsg WireMessage, + sessionLookup SessionLookup) (IncomingMsg, error) { + switch wireMsg.MsgType { case MsgTypeRequest: return NewIncomingRequestFromWire(wireMsg) case MsgTypeAccept: - return NewIncomingAcceptFromWire(wireMsg) + return NewIncomingAcceptFromWire(wireMsg, sessionLookup) case MsgTypeReject: return NewQuoteRejectFromWireMsg(wireMsg) default: @@ -156,6 +162,12 @@ func WireMsgDataVersionDecoder(r io.Reader, val any, buf *[8]byte, // IncomingMsg is an interface that represents an inbound wire message // that has been received from a peer. type IncomingMsg interface { + // MsgPeer returns the peer that sent the message. + MsgPeer() route.Vertex + + // MsgID returns the quote request session ID. + MsgID() ID + // String returns a human-readable string representation of the message. String() string } diff --git a/rfqmsg/reject.go b/rfqmsg/reject.go index fdf6c9a06..69a1a9ed6 100644 --- a/rfqmsg/reject.go +++ b/rfqmsg/reject.go @@ -243,6 +243,16 @@ func (q *Reject) ToWire() (WireMessage, error) { }, nil } +// MsgPeer returns the peer that sent the message. +func (q *Reject) MsgPeer() route.Vertex { + return q.Peer +} + +// MsgID returns the quote request session ID. +func (q *Reject) MsgID() ID { + return q.ID +} + // String returns a human-readable string representation of the message. func (q *Reject) String() string { return fmt.Sprintf("Reject(id=%x, err_code=%d, err_msg=%s)", diff --git a/rfqmsg/sell_accept.go b/rfqmsg/sell_accept.go index e4c869135..b6204a6c8 100644 --- a/rfqmsg/sell_accept.go +++ b/rfqmsg/sell_accept.go @@ -58,7 +58,8 @@ func NewSellAcceptFromRequest(request SellRequest, bidPrice lnwire.MilliSatoshi, // newSellAcceptFromWireMsg instantiates a new instance from a wire message. func newSellAcceptFromWireMsg(wireMsg WireMessage, - msgData acceptWireMsgData) (*SellAccept, error) { + msgData acceptWireMsgData, request SellRequest) (*SellAccept, + error) { // Ensure that the message type is an accept message. if wireMsg.MsgType != MsgTypeAccept { @@ -82,6 +83,7 @@ func newSellAcceptFromWireMsg(wireMsg WireMessage, // service. return &SellAccept{ Peer: wireMsg.Peer, + Request: request, Version: msgData.Version.Val, ID: msgData.ID.Val, BidPrice: bidPrice, diff --git a/rfqmsg/sell_request.go b/rfqmsg/sell_request.go index 38ee806fb..a34fc90e7 100644 --- a/rfqmsg/sell_request.go +++ b/rfqmsg/sell_request.go @@ -174,6 +174,16 @@ func (q *SellRequest) ToWire() (WireMessage, error) { }, nil } +// MsgPeer returns the peer that sent the message. +func (q *SellRequest) MsgPeer() route.Vertex { + return q.Peer +} + +// MsgID returns the quote request session ID. +func (q *SellRequest) MsgID() ID { + return q.ID +} + // String returns a human-readable string representation of the message. func (q *SellRequest) String() string { var groupKeyBytes []byte