diff --git a/control.go b/control.go index 008adfe..096b09d 100644 --- a/control.go +++ b/control.go @@ -51,8 +51,9 @@ type ControlBlock struct { // 1 - old sequence numbers which have been acknowledged // 2 - sequence numbers allowed for new reception // 3 - future sequence numbers which are not yet allowed - rcv recvSpace - rstPtr Value // RST pointer. See RFC 3540. + rcv recvSpace + // When FlagRST is set in pending flags rstPtr will contain the sequence number of the RST segment to make it "believable" (See RFC9293) + rstPtr Value // pending and state are modified by rcv* methods and Close method. // The pending flags are only updated if the Recv method finishes with no error. pending Flags @@ -78,7 +79,7 @@ type recvSpace struct { WND Size // receive window defined by local. Permitted number unacked octets in flight. } -// Segment represents a TCP segment as the sequence number of the first octet and the length of the segment. +// Segment represents an incoming/outgoing TCP segment in the sequence space. type Segment struct { SEQ Value // sequence number of first octet of segment. If SYN is set it is the initial sequence number (ISN) and the first data octet is ISN+1. ACK Value // acknowledgment number. If ACK is set it is sequence number of first octet the sender of the segment is expecting to receive next. @@ -112,6 +113,11 @@ func (tcb *ControlBlock) PendingSegment(payloadLen int) (_ Segment, ok bool) { payloadLen = int(tcb.snd.WND) } + pending := tcb.pending + if payloadLen > 0 { + pending |= FlagPSH + } + var ack Value if tcb.pending.HasAny(FlagACK) { ack = tcb.rcv.NXT @@ -126,7 +132,7 @@ func (tcb *ControlBlock) PendingSegment(payloadLen int) (_ Segment, ok bool) { SEQ: seq, ACK: ack, WND: tcb.rcv.WND, - Flags: tcb.pending, + Flags: pending, DATALEN: Size(payloadLen), } return seg, true @@ -326,6 +332,25 @@ func (tcb *ControlBlock) validateOutgoingSegment(seg Segment) (err error) { return err } +// close sets ControlBlock state to closed and resets all sequence numbers and pending flag. +func (tcb *ControlBlock) close() { + tcb.state = StateClosed + tcb.pending = 0 + tcb.resetRcv(0, 0) + tcb.resetSnd(0, 0) + tcb.debuglog += "close tcb\n" +} + +// hasIRS checks if the ControlBlock has received a valid initial sequence number (IRS). +func (tcb *ControlBlock) hasIRS() bool { + return tcb.isOpen() && tcb.state != StateSynSent && tcb.state != StateListen +} + +// isOpen checks if the ControlBlock is in a state that allows sending and/or receiving data. +func (tcb *ControlBlock) isOpen() bool { + return tcb.state != StateClosed && tcb.state != StateTimeWait +} + // Flags is a TCP flags masked implementation i.e: SYN, FIN, ACK. type Flags uint16 @@ -339,8 +364,10 @@ const ( FlagECE // FlagECE - ECN-Echo has a nonce-sum in the SYN/ACK. FlagCWR // FlagCWR - Congestion Window Reduced. FlagNS // FlagNS - Nonce Sum flag (see RFC 3540). +) - // The union of SYN and ACK flags is commonly found throughout the specification, so we define a shorthand. +// The union of SYN|FIN|PSH and ACK flags is commonly found throughout the specification, so we define unexported shorthands. +const ( synack = FlagSYN | FlagACK finack = FlagFIN | FlagACK pshack = FlagPSH | FlagACK diff --git a/control_user.go b/control_user.go index bc54099..cc6828e 100644 --- a/control_user.go +++ b/control_user.go @@ -31,6 +31,7 @@ func (tcb *ControlBlock) Open(iss Value, wnd Size, state State) (err error) { tcb.state = state tcb.resetRcv(wnd, 0) tcb.resetSnd(iss, 1) + tcb.pending = 0 if state == StateSynSent { tcb.pending = FlagSYN } @@ -98,7 +99,7 @@ func (tcb *ControlBlock) Recv(seg Segment) (err error) { case StateCloseWait: case StateLastAck: if seg.Flags.HasAny(FlagACK) { - tcb.state = StateClosed + tcb.close() } default: err = errors.New("rcv: unexpected state " + tcb.state.String()) @@ -122,8 +123,20 @@ func (tcb *ControlBlock) Recv(seg Segment) (err error) { return err } -// RecvNext returns the next sequence number expected to be received. +// RecvNext returns the next sequence number expected to be received from remote. // This implementation will reject segments that are not the next expected sequence. -func (tcb *ControlBlock) RecvNext() Value { - return tcb.rcv.NXT +// RecvNext returns 0 before StateSynRcvd. +func (tcb *ControlBlock) RecvNext() Value { return tcb.rcv.NXT } + +// ISS returns the initial sequence number of the connection that was defined on a call to Open by user. +func (tcb *ControlBlock) ISS() Value { return tcb.snd.ISS } + +// MaxOutgoingSegmentSize returns the maximum size of a segment that can be sent by taking into account +// the send window size and the unacked data. Returns 0 before StateSynRcvd. +func (tcb *ControlBlock) MaxOutgoingSegmentSize() Size { + if !tcb.hasIRS() { + return 0 // SYN not yet received. + } + unacked := Sizeof(tcb.snd.UNA, tcb.snd.NXT) + return tcb.snd.WND - unacked - 1 // TODO: is this -1 supposed to be here? } diff --git a/stack/stack.go b/stack/portstack.go similarity index 72% rename from stack/stack.go rename to stack/portstack.go index 7b56131..0e432f8 100644 --- a/stack/stack.go +++ b/stack/portstack.go @@ -7,6 +7,7 @@ import ( "io" "log/slog" "net" + "net/netip" "time" "github.com/soypat/seqs/eth" @@ -18,14 +19,14 @@ const ( type StackConfig struct { MAC net.HardwareAddr - IP net.IP + IP netip.Addr MaxUDPConns int MaxTCPConns int } // NewStack creates a ready to use TCP/UDP Stack instance. -func NewStack(cfg StackConfig) *Stack { - var s Stack +func NewStack(cfg StackConfig) *PortStack { + var s PortStack s.MAC = cfg.MAC s.IP = cfg.IP s.UDPv4 = make([]udpSocket, cfg.MaxUDPConns) @@ -33,14 +34,17 @@ func NewStack(cfg StackConfig) *Stack { return &s } -// Stack is a TCP/UDP netlink implementation for muxing packets received into -// their respective sockets with [Stack.RcvEth]. -type Stack struct { +// PortStack implements partial TCP/UDP packet muxing to respective sockets with [PortStack.RcvEth]. +// This implementation limits itself basic header validation and port matching. +// Users of PortStack are expected to implement connection state, packet buffering and retransmission logic. +// - In the case of TCP this means implementing the TCP state machine. +// - In the case of UDP PortStack should be enough to build most applications. +type PortStack struct { lastRx time.Time lastRxSuccess time.Time MAC net.HardwareAddr // Set IP to non-nil to ignore packets not meant for us. - IP net.IP + IP netip.Addr UDPv4 []udpSocket TCPv4 []tcpSocket GlobalHandler func([]byte) @@ -73,16 +77,16 @@ var ( // // If [Stack.HandleEth] is not called often enough prevent packet queue from // filling up on a socket RecvEth will start to return [ErrDroppedPacket]. -func (s *Stack) RecvEth(ethernetFrame []byte) (err error) { +func (ps *PortStack) RecvEth(ethernetFrame []byte) (err error) { var ehdr eth.EthernetHeader var ihdr eth.IPv4Header defer func() { if err != nil { - s.error("Stack.RecvEth", slog.String("err", err.Error()), slog.Any("IP", ihdr)) + ps.error("Stack.RecvEth", slog.String("err", err.Error()), slog.Any("IP", ihdr)) } else { - s.lastRxSuccess = s.lastRx - if s.GlobalHandler != nil { - s.GlobalHandler(ethernetFrame) + ps.lastRxSuccess = ps.lastRx + if ps.GlobalHandler != nil { + ps.GlobalHandler(ethernetFrame) } } }() @@ -90,17 +94,22 @@ func (s *Stack) RecvEth(ethernetFrame []byte) (err error) { if len(payload) < eth.SizeEthernetHeader+eth.SizeIPv4Header { return errPacketSmol } - s.debug("Stack.RecvEth:start", slog.Int("plen", len(payload))) - s.lastRx = time.Now() + ps.debug("Stack.RecvEth:start", slog.Int("plen", len(payload))) + ps.lastRx = time.Now() // Ethernet parsing block ehdr = eth.DecodeEthernetHeader(payload) - if s.MAC != nil && !eth.IsBroadcastHW(ehdr.Destination[:]) && !bytes.Equal(ehdr.Destination[:], s.MAC) { + etype := ehdr.AssertType() + if ps.MAC != nil && !eth.IsBroadcastHW(ehdr.Destination[:]) && !bytes.Equal(ehdr.Destination[:], ps.MAC) { return nil // Ignore packet, is not for us. - } else if ehdr.AssertType() != eth.EtherTypeIPv4 { + } else if etype != eth.EtherTypeIPv4 && etype != eth.EtherTypeARP { return nil // Ignore Non-IPv4 packets. } + if etype == eth.EtherTypeARP { + + } + // IP parsing block. var ipOffset uint8 ihdr, ipOffset = eth.DecodeIPv4Header(payload[eth.SizeEthernetHeader:]) @@ -111,7 +120,8 @@ func (s *Stack) RecvEth(ethernetFrame []byte) (err error) { return errIPVersion case ipOffset < eth.SizeIPv4Header: return errInvalidIHL - case s.IP != nil && string(ihdr.Destination[:]) != string(s.IP): + + case ps.IP.Compare(netip.AddrFrom4(ihdr.Destination)) != 0: return nil // Not for us. case uint16(offset) > end || int(offset) > len(payload) || int(end) > len(payload): return errors.New("bad IP TotalLength/IHL") @@ -121,9 +131,10 @@ func (s *Stack) RecvEth(ethernetFrame []byte) (err error) { ipOptions := payload[eth.SizeEthernetHeader+eth.SizeIPv4Header : offset] // TODO add IPv4 options. payload = payload[offset:end] switch ihdr.Protocol { + case 17: // UDP (User Datagram Protocol). - if len(s.UDPv4) == 0 { + if len(ps.UDPv4) == 0 { return nil // No sockets. } else if len(payload) < eth.SizeUDPHeader { return errTooShortTCPOrUDP @@ -142,21 +153,21 @@ func (s *Stack) RecvEth(ethernetFrame []byte) (err error) { return errChecksumTCPorUDP } - socket := s.getUDP(uhdr.DestinationPort) + socket := ps.getUDP(uhdr.DestinationPort) if socket == nil { break // No socket listening on this port. } else if socket.NeedsHandling() { - s.error("UDP packet dropped") - s.droppedPackets++ + ps.error("UDP packet dropped") + ps.droppedPackets++ return ErrDroppedPacket // Our socket needs handling before admitting more packets. } // The packet is meant for us. We handle it. - s.info("UDP packet stored", slog.Int("plen", len(payload))) + ps.info("UDP packet stored", slog.Int("plen", len(payload))) // Flag packets as needing processing. - s.pendingUDPv4++ - socket.LastRx = s.lastRx // set as unhandled here. + ps.pendingUDPv4++ + socket.LastRx = ps.lastRx // set as unhandled here. - socket.packets[0].Rx = s.lastRx + socket.packets[0].Rx = ps.lastRx socket.packets[0].Eth = ehdr socket.packets[0].IP = ihdr socket.packets[0].UDP = uhdr @@ -164,10 +175,10 @@ func (s *Stack) RecvEth(ethernetFrame []byte) (err error) { copy(socket.packets[0].payload[:], payload) case 6: - s.info("TCP packet received", slog.Int("plen", len(payload))) + ps.info("TCP packet received", slog.Int("plen", len(payload))) // TCP (Transport Control Protocol). switch { - case len(s.TCPv4) == 0: + case len(ps.TCPv4) == 0: return nil case len(payload) < eth.SizeTCPHeader: return errTooShortTCPOrUDP @@ -188,20 +199,20 @@ func (s *Stack) RecvEth(ethernetFrame []byte) (err error) { println("bad checksum") } - socket := s.getTCP(thdr.DestinationPort) + socket := ps.getTCP(thdr.DestinationPort) if socket == nil { break // No socket listening on this port. } else if socket.NeedsHandling() { - s.error("TCP packet dropped") - s.droppedPackets++ + ps.error("TCP packet dropped") + ps.droppedPackets++ return ErrDroppedPacket // Our socket needs handling before admitting more packets. } - s.info("TCP packet stored", slog.Int("plen", len(payload))) + ps.info("TCP packet stored", slog.Int("plen", len(payload))) // Flag packets as needing processing. - s.pendingTCPv4++ - socket.LastRx = s.lastRx // set as unhandled here. + ps.pendingTCPv4++ + socket.LastRx = ps.lastRx // set as unhandled here. - socket.packets[0].Rx = s.lastRx + socket.packets[0].Rx = ps.lastRx socket.packets[0].Eth = ehdr socket.packets[0].IP = ihdr socket.packets[0].TCP = thdr @@ -218,18 +229,18 @@ func (s *Stack) RecvEth(ethernetFrame []byte) (err error) { // not processed and that a future call to HandleEth is required to complete. // // If a handler returns any other error the port is closed. -func (s *Stack) HandleEth(dst []byte) (n int, err error) { +func (ps *PortStack) HandleEth(dst []byte) (n int, err error) { switch { case len(dst) < _MTU: return 0, io.ErrShortBuffer - case s.pendingUDPv4 == 0 && s.pendingTCPv4 == 0: + case ps.pendingUDPv4 == 0 && ps.pendingTCPv4 == 0: return 0, nil // No packets to handle } - s.info("HandleEth", slog.Int("dstlen", len(dst))) - if s.pendingUDPv4 > 0 { - for i := range s.UDPv4 { - socket := &s.UDPv4[i] + ps.info("HandleEth", slog.Int("dstlen", len(dst))) + if ps.pendingUDPv4 > 0 { + for i := range ps.UDPv4 { + socket := &ps.UDPv4[i] if !socket.IsPendingHandling() { return 0, nil } @@ -240,7 +251,7 @@ func (s *Stack) HandleEth(dst []byte) (n int, err error) { err = nil continue } - s.pendingUDPv4-- + ps.pendingUDPv4-- if err != nil { socket.Close() return 0, err @@ -252,8 +263,8 @@ func (s *Stack) HandleEth(dst []byte) (n int, err error) { } } - if n == 0 && s.pendingTCPv4 > 0 { - socketList := s.TCPv4 + if n == 0 && ps.pendingTCPv4 > 0 { + socketList := ps.TCPv4 for i := range socketList { socket := &socketList[i] if !socket.IsPendingHandling() { @@ -266,7 +277,7 @@ func (s *Stack) HandleEth(dst []byte) (n int, err error) { err = nil continue } - s.pendingTCPv4-- + ps.pendingTCPv4-- if err != nil { socket.Close() return 0, err @@ -279,14 +290,14 @@ func (s *Stack) HandleEth(dst []byte) (n int, err error) { } if n != 0 && err == nil { - s.processedPackets++ + ps.processedPackets++ } return n, err } // OpenUDP opens a UDP port and sets the handler. If the port is already open // or if there is no socket available it returns an error. -func (s *Stack) OpenUDP(port uint16, handler func([]byte, *UDPPacket) (int, error)) error { +func (ps *PortStack) OpenUDP(port uint16, handler func([]byte, *UDPPacket) (int, error)) error { switch { case port == 0: return errZeroPort @@ -294,7 +305,7 @@ func (s *Stack) OpenUDP(port uint16, handler func([]byte, *UDPPacket) (int, erro return errNilHandler } availIdx := -1 - socketList := s.UDPv4 + socketList := ps.UDPv4 for i := range socketList { socket := &socketList[i] if socket.Port == port { @@ -313,7 +324,7 @@ func (s *Stack) OpenUDP(port uint16, handler func([]byte, *UDPPacket) (int, erro // FlagUDPPending flags the socket listening on a given port as having a pending // packet. This is useful to force a response even if no packet has been received. -func (s *Stack) FlagUDPPending(port uint16) error { +func (s *PortStack) FlagUDPPending(port uint16) error { if port == 0 { return errZeroPort } @@ -328,20 +339,20 @@ func (s *Stack) FlagUDPPending(port uint16) error { } // CloseUDP closes a UDP socket. -func (s *Stack) CloseUDP(port uint16) error { +func (ps *PortStack) CloseUDP(port uint16) error { if port == 0 { return errZeroPort } - socket := s.getUDP(port) + socket := ps.getUDP(port) if socket == nil { return errNoSocketAvail } - s.pendingUDPv4 -= uint32(socket.pending()) + ps.pendingUDPv4 -= uint32(socket.pending()) socket.Close() return nil } -func (s *Stack) getUDP(port uint16) *udpSocket { +func (s *PortStack) getUDP(port uint16) *udpSocket { for i := range s.UDPv4 { socket := &s.UDPv4[i] if socket.Port == port { @@ -353,7 +364,7 @@ func (s *Stack) getUDP(port uint16) *udpSocket { // OpenTCP opens a TCP port and sets the handler. If the port is already open // or if there is no socket available it returns an error. -func (s *Stack) OpenTCP(port uint16, handler tcphandler) error { +func (ps *PortStack) OpenTCP(port uint16, handler tcphandler) error { switch { case port == 0: return errZeroPort @@ -362,7 +373,7 @@ func (s *Stack) OpenTCP(port uint16, handler tcphandler) error { } availIdx := -1 - socketList := s.TCPv4 + socketList := ps.TCPv4 for i := range socketList { socket := &socketList[i] if socket.Port == port { @@ -381,37 +392,37 @@ func (s *Stack) OpenTCP(port uint16, handler tcphandler) error { // FlagTCPPending flags the socket listening on a given port as having a pending // packet. This is useful to force a response even if no packet has been received. -func (s *Stack) FlagTCPPending(port uint16) error { +func (ps *PortStack) FlagTCPPending(port uint16) error { if port == 0 { return errZeroPort } - socket := s.getTCP(port) + socket := ps.getTCP(port) if socket == nil { return errNoSocketAvail } if socket.forceResponse() { - s.pendingTCPv4++ + ps.pendingTCPv4++ } return nil } // CloseTCP closes a TCP socket. -func (s *Stack) CloseTCP(port uint16) error { +func (ps *PortStack) CloseTCP(port uint16) error { if port == 0 { return errZeroPort } - socket := s.getTCP(port) + socket := ps.getTCP(port) if socket == nil { return errNoSocketAvail } - s.pendingTCPv4 -= socket.pending() + ps.pendingTCPv4 -= socket.pending() socket.Close() return nil } -func (s *Stack) getTCP(port uint16) *tcpSocket { - for i := range s.TCPv4 { - socket := &s.TCPv4[i] +func (ps *PortStack) getTCP(port uint16) *tcpSocket { + for i := range ps.TCPv4 { + socket := &ps.TCPv4[i] if socket.Port == port { return socket } @@ -419,20 +430,20 @@ func (s *Stack) getTCP(port uint16) *tcpSocket { return nil } -func (s *Stack) info(msg string, attrs ...slog.Attr) { - s.logAttrsPrint(slog.LevelInfo, msg, attrs...) +func (ps *PortStack) info(msg string, attrs ...slog.Attr) { + ps.logAttrsPrint(slog.LevelInfo, msg, attrs...) } -func (s *Stack) error(msg string, attrs ...slog.Attr) { - s.logAttrsPrint(slog.LevelError, msg, attrs...) +func (ps *PortStack) error(msg string, attrs ...slog.Attr) { + ps.logAttrsPrint(slog.LevelError, msg, attrs...) } -func (s *Stack) debug(msg string, attrs ...slog.Attr) { - s.logAttrsPrint(slog.LevelDebug, msg, attrs...) +func (ps *PortStack) debug(msg string, attrs ...slog.Attr) { + ps.logAttrsPrint(slog.LevelDebug, msg, attrs...) } -func (s *Stack) logAttrsPrint(level slog.Level, msg string, attrs ...slog.Attr) { - if s.level <= level { +func (ps *PortStack) logAttrsPrint(level slog.Level, msg string, attrs ...slog.Attr) { + if ps.level <= level { logAttrsPrint(level, msg, attrs...) } } diff --git a/stack/socket_tcp.go b/stack/socket_tcp.go index 5e5a89b..2dcb231 100644 --- a/stack/socket_tcp.go +++ b/stack/socket_tcp.go @@ -108,6 +108,9 @@ func (p *TCPPacket) PutHeaders(b []byte) { if len(b) < minSize { panic("short tcpPacket buffer") } + if p.IP.IHL() != 5 || p.TCP.Offset() != 5 { + panic("TCPPacket.PutHeaders expects no IP or TCP options") + } p.Eth.Put(b) p.IP.Put(b[eth.SizeEthernetHeader:]) p.TCP.Put(b[eth.SizeEthernetHeader+eth.SizeIPv4Header:]) diff --git a/stack/socket_udp.go b/stack/socket_udp.go index 6e7bd57..0ddf627 100644 --- a/stack/socket_udp.go +++ b/stack/socket_udp.go @@ -100,6 +100,9 @@ func (p *UDPPacket) PutHeaders(b []byte) { if len(b) < eth.SizeEthernetHeader+eth.SizeIPv4Header+eth.SizeUDPHeader { panic("short UDPPacket buffer") } + if p.IP.IHL() != 5 { + panic("UDPPacket.PutHeaders expects no IP options") + } p.Eth.Put(b) p.IP.Put(b[eth.SizeEthernetHeader:]) p.UDP.Put(b[eth.SizeEthernetHeader+eth.SizeIPv4Header:]) diff --git a/stack/stack_test.go b/stack/stack_test.go index 3123cb9..df8227f 100644 --- a/stack/stack_test.go +++ b/stack/stack_test.go @@ -2,7 +2,6 @@ package stack_test import ( "errors" - "fmt" "net/netip" "strings" "testing" @@ -31,7 +30,7 @@ func TestStackEstablish(t *testing.T) { ClientTCB := seqs.ControlBlock{} Client := stack.NewStack(stack.StackConfig{ MAC: macClient[:], - IP: ipClient.Addr().AsSlice(), + IP: ipClient.Addr(), MaxTCPConns: 1, }) @@ -96,93 +95,74 @@ func TestStackEstablish(t *testing.T) { t.Fatal(err) } - ServerTCB := seqs.ControlBlock{} - err = ServerTCB.Open(serverISS, clientISS, seqs.StateListen) - if err != nil { - t.Fatal(err) - } Server := stack.NewStack(stack.StackConfig{ MAC: macServer[:], - IP: ipServer.Addr().AsSlice(), + IP: ipServer.Addr(), MaxTCPConns: 1, }) - err = Server.OpenTCP(ipServer.Port(), func(response []byte, pkt *stack.TCPPacket) (n int, err error) { - defer func() { - if n > 0 && err == nil { - t.Logf("Server sent: %s", pkt.String()) - } - }() - if pkt.HasPacket() { - t.Logf("Server received: %s", pkt.String()) - payload := pkt.Payload() - err = ServerTCB.Recv(pkt.TCP.Segment(len(payload))) - if err != nil { - return 0, err - } - segOut, ok := ServerTCB.PendingSegment(0) - if !ok { - return 0, nil - } - pkt.InvertSrcDest() - pkt.CalculateHeaders(segOut, nil) - pkt.PutHeaders(response) - return 54, ServerTCB.Send(segOut) - } - return 0, nil - }) + serverTCP, err := stack.ListenTCP(Server, ipServer.Port(), serverISS, serverWND) if err != nil { t.Fatal(err) } + // 3 way handshake needs 3 exchanges to complete. + const maxExchanges = 3 + exchanges, dataExchanged := exchangeStacks(t, maxExchanges, Client, Server) + const expectedData = (eth.SizeEthernetHeader + eth.SizeIPv4Header + eth.SizeTCPHeader) * 4 + if dataExchanged < expectedData { + t.Fatal("too little data exchanged", dataExchanged, " want>=", expectedData) + } + if exchanges >= 4 { + t.Fatal("too many exchanges for a 3 way handshake") + } + if exchanges <= 2 { + t.Fatal("too few exchanges for a 3 way handshake") + } + if ClientTCB.State() != seqs.StateEstablished { + t.Fatal("client not established") + } + if serverTCP.State() != seqs.StateEstablished { + t.Fatal("server not established") + } +} + +func isDroppedPacket(err error) bool { + return err != nil && (errors.Is(err, stack.ErrDroppedPacket) || strings.HasPrefix(err.Error(), "drop")) +} + +func exchangeStacks(t *testing.T, maxExchanges int, stacks ...*stack.PortStack) (exchanges, totalData int) { + loops := 0 var pipe [2048]byte zeroPipe := func() { pipe = [2048]byte{} } sprintErr := func(err error) string { - return fmt.Sprintf("%v: client=%s server=%s", err, ClientTCB.State(), ServerTCB.State()) + return err.Error() + // return fmt.Sprintf("%v: client=%s server=%s", err, ClientTCB.State(), ServerTCB.State()) } - - // 3 way handshake needs 3 exchanges to complete. - const maxExchanges = 3 - loops := 0 + totalDataSent := 0 for loops <= maxExchanges { loops++ - nc, err := Client.HandleEth(pipe[:]) - if err != nil && !isDroppedPacket(err) { - t.Fatal("client handle:", sprintErr(err)) - } - if nc > 0 { - err = Server.RecvEth(pipe[:nc]) + sent := 0 + for isender := 0; isender < len(stacks); isender++ { + n, err := stacks[isender].HandleEth(pipe[:]) if err != nil && !isDroppedPacket(err) { - t.Fatal("sv recv:", sprintErr(err)) + t.Fatalf("send[%d]: %s", isender, sprintErr(err)) } - zeroPipe() - } - - ns, err := Server.HandleEth(pipe[:]) - if err != nil && !isDroppedPacket(err) { - t.Fatal("sv handle:", sprintErr(err)) - } - if ns > 0 { - err = Client.RecvEth(pipe[:ns]) - if err != nil && !isDroppedPacket(err) { - t.Fatal("client recv:", sprintErr(err)) + sent += n + for ireceiver := 0; n > 0 && ireceiver < len(stacks); ireceiver++ { + if ireceiver == isender { + continue + } + err = stacks[ireceiver].RecvEth(pipe[:n]) + if err != nil && !isDroppedPacket(err) { + t.Fatalf("recv[%d]: %s", ireceiver, sprintErr(err)) + } } zeroPipe() } - if ns == 0 && nc == 0 { + totalDataSent += sent + if sent == 0 { break // No more data being interchanged. } } - if loops > maxExchanges { - t.Fatal("unending connection established") - } - if ClientTCB.State() != seqs.StateEstablished { - t.Fatal("client not established") - } - if ServerTCB.State() != seqs.StateEstablished { - t.Fatal("server not established") - } -} - -func isDroppedPacket(err error) bool { - return err != nil && (errors.Is(err, stack.ErrDroppedPacket) || strings.HasPrefix(err.Error(), "drop")) + return loops, totalDataSent } diff --git a/stack/tcpstack.go b/stack/tcpstack.go new file mode 100644 index 0000000..5987baa --- /dev/null +++ b/stack/tcpstack.go @@ -0,0 +1,114 @@ +package stack + +import ( + "net/netip" + "time" + + "github.com/soypat/seqs" +) + +type tcp struct { + stack *PortStack + scb seqs.ControlBlock + localPort uint16 + iss seqs.Value + wnd seqs.Size + lastTx time.Time + lastRx time.Time + // Remote fields discovered during an active open. + remote netip.AddrPort + remoteMAC [6]byte +} + +func (t *tcp) State() seqs.State { + return t.scb.State() +} + +// ListenTCP opens a passive TCP connection that listens on the given port. +func ListenTCP(stack *PortStack, port uint16, iss seqs.Value, window seqs.Size) (*tcp, error) { + t := tcp{ + stack: stack, + localPort: port, + } + err := stack.OpenTCP(port, t.handleMain) + if err != nil { + return nil, err + } + err = t.scb.Open(iss, window, seqs.StateListen) + if err != nil { + return nil, err + } + return &t, nil +} + +func (t *tcp) handleMain(response []byte, pkt *TCPPacket) (n int, err error) { + if t.mustSendSyn() { + // Connection is still closed, we need to establish + return t.handleInitSyn(response, pkt) + } + if pkt.HasPacket() { + t.lastRx = pkt.Rx + n, err := t.handleRecv(response, pkt) + if n > 0 || err != nil { + return n, err // Return early if something happened, else yield to user data handler. + } + } + return t.handleUser(response, pkt) +} + +func (t *tcp) handleRecv(response []byte, pkt *TCPPacket) (n int, err error) { + // By this point we know that the packet is valid and contains data, we process it. + payload := pkt.Payload() + segIncoming := pkt.TCP.Segment(len(payload)) + // if segIncoming.SEQ != t.scb.RecvNext() { + // return 0, ErrDroppedPacket // SCB does not admit out-of-order packets. + // } + err = t.scb.Recv(segIncoming) + if err != nil { + return 0, err + } + segOut, ok := t.scb.PendingSegment(0) + if !ok { + return 0, nil // Yield to handleUser. + } + pkt.InvertSrcDest() + pkt.CalculateHeaders(segOut, nil) + pkt.PutHeaders(response) + return 54, t.scb.Send(segOut) +} + +func (t *tcp) handleUser(response []byte, pkt *TCPPacket) (n int, err error) { + + return 0, nil +} + +func (t *tcp) handleInitSyn(response []byte, pkt *TCPPacket) (n int, err error) { + // Uninitialized TCB, we start the handshake. + iss := t.iss + wnd := t.wnd + err = t.scb.Open(iss, wnd, seqs.StateSynSent) + if err != nil { + return 0, err + } + outSeg := seqs.Segment{ + SEQ: iss, + ACK: 0, + Flags: seqs.FlagSYN, + WND: wnd, + } + copy(pkt.Eth.Source[:], t.stack.MAC) + pkt.IP.Source = t.stack.IP.As4() + pkt.TCP.SourcePort = t.localPort + + pkt.IP.Destination = t.remote.Addr().As4() + pkt.TCP.DestinationPort = t.remote.Port() + pkt.Eth.Destination = t.remoteMAC + + pkt.CalculateHeaders(outSeg, nil) + pkt.PutHeaders(response) + return 54, t.scb.Send(outSeg) +} + +func (t *tcp) mustSendSyn() bool { + return t.lastTx.IsZero() && t.scb.State() == seqs.StateClosed +}