Skip to content

Commit

Permalink
refactor TCP/UDP port internals
Browse files Browse the repository at this point in the history
  • Loading branch information
soypat committed Nov 20, 2023
1 parent 40d0a58 commit 19e1a1a
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 52 deletions.
47 changes: 24 additions & 23 deletions stack/port_tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,6 @@ type tcpPort struct {
packets [1]TCPPacket
}

const tcpMTU = _MTU - eth.SizeEthernetHeader - eth.SizeIPv4Header - eth.SizeTCPHeader

type TCPPacket struct {
Rx time.Time
Eth eth.EthernetHeader
IP eth.IPv4Header
TCP eth.TCPHeader
// data contains TCP+IP options and then the actual data.
data [tcpMTU]byte
}

func (p *TCPPacket) String() string {
return "TCP Packet: " + p.Eth.String() + " " + p.IP.String() + " " + p.TCP.String() + " payload:" + strconv.Quote(string(p.Payload()))
}

func (p tcpPort) Port() uint16 { return p.port }

// NeedsHandling returns true if the socket needs handling before it can
Expand All @@ -52,7 +37,7 @@ func (u *tcpPort) NeedsHandling() bool {

// IsPendingHandling returns true if there are packet(s) pending handling.
func (u *tcpPort) IsPendingHandling() bool {
return u.port != 0 && !u.packets[0].Rx.IsZero()
return u.port != 0 && u.packets[0].pendingHandling()
}

// HandleEth writes the socket's response into dst to be sent over an ethernet interface.
Expand All @@ -65,9 +50,9 @@ func (u *tcpPort) HandleEth(dst []byte) (n int, err error) {

n, err = u.handler(dst, &u.packets[0])
if err == ErrFlagPending {
packet.Rx = forcedTime // Mark socket as needing handling but packet having no data.
packet.flagPendingNoPacket() // Mark socket as needing handling but packet having no data.
} else {
packet.Rx = time.Time{} // Invalidate packet normally.
packet.invalidate()
}
return n, err
}
Expand All @@ -80,13 +65,13 @@ func (u *tcpPort) Open(port uint16, handler tcphandler) {
u.handler = handler
u.port = port
for i := range u.packets {
u.packets[i].Rx = time.Time{} // Invalidate packets.
u.packets[i].invalidate()
}
}

func (s *tcpPort) pending() (p uint32) {
for i := range s.packets {
if s.packets[i].HasPacket() {
if s.packets[i].pendingHandling() {
p++
}
}
Expand All @@ -101,15 +86,31 @@ func (u *tcpPort) Close() {
func (u *tcpPort) forceResponse() (added bool) {
if !u.IsPendingHandling() {
added = true
u.packets[0].Rx = forcedTime
u.packets[0].flagPendingNoPacket()
}
return added
}

func (u *TCPPacket) HasPacket() bool {
return u.Rx != forcedTime && !u.Rx.IsZero()
const tcpMTU = _MTU - eth.SizeEthernetHeader - eth.SizeIPv4Header - eth.SizeTCPHeader

type TCPPacket struct {
Rx time.Time
Eth eth.EthernetHeader
IP eth.IPv4Header
TCP eth.TCPHeader
// data contains TCP+IP options and then the actual data.
data [tcpMTU]byte
}

func (p *TCPPacket) String() string {
return "TCP Packet: " + p.Eth.String() + " " + p.IP.String() + " " + p.TCP.String() + " payload:" + strconv.Quote(string(p.Payload()))
}

func (u *TCPPacket) HasPacket() bool { return u.Rx != forcedTime && !u.Rx.IsZero() }
func (u *TCPPacket) pendingHandling() bool { return !u.Rx.IsZero() }
func (u *TCPPacket) invalidate() { u.Rx = time.Time{} }
func (u *TCPPacket) flagPendingNoPacket() { u.Rx = forcedTime }

// PutHeaders puts 54 bytes including the Ethernet, IPv4 and TCP headers into b.
// b must be at least 54 bytes in length or else PutHeaders panics. No options are marshalled.
func (p *TCPPacket) PutHeaders(b []byte) {
Expand Down
33 changes: 17 additions & 16 deletions stack/port_udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,6 @@ type udpPort struct {
packets [1]UDPPacket
}

type UDPPacket struct {
Rx time.Time
Eth eth.EthernetHeader
IP eth.IPv4Header
UDP eth.UDPHeader
payload [_MTU - eth.SizeEthernetHeader - eth.SizeIPv4Header - eth.SizeUDPHeader]byte
}

func (u udpPort) Port() uint16 { return u.port }

// NeedsHandling returns true if the socket needs handling before it can
Expand All @@ -37,7 +29,7 @@ func (u *udpPort) NeedsHandling() bool {

// IsPendingHandling returns true if there are packet(s) pending handling.
func (u *udpPort) IsPendingHandling() bool {
return u.port != 0 && !u.packets[0].Rx.IsZero()
return u.port != 0 && u.packets[0].pendingHandling()
}

// HandleEth writes the socket's response into dst to be sent over an ethernet interface.
Expand All @@ -50,9 +42,9 @@ func (u *udpPort) HandleEth(dst []byte) (int, error) {

n, err := u.handler(dst, &u.packets[0])
if err == ErrFlagPending {
packet.Rx = forcedTime // Mark socket as needing handling but packet having no data.
packet.flagPendingNoPacket() // Mark socket as needing handling but packet having no data.
} else {
packet.Rx = time.Time{} // Invalidate packet normally.
packet.invalidate()
}
return n, err
}
Expand All @@ -68,7 +60,7 @@ func (u *udpPort) Open(port uint16, h udphandler) {

func (s *udpPort) pending() (p int) {
for i := range s.packets {
if s.packets[i].HasPacket() {
if s.packets[i].pendingHandling() {
p++
}
}
Expand All @@ -78,7 +70,7 @@ func (s *udpPort) pending() (p int) {
func (u *udpPort) Close() {
u.port = 0 // Port 0 flags the port is inactive.
for i := range u.packets {
u.packets[i].Rx = time.Time{} // Invalidate packets.
u.packets[i].invalidate()
}
}

Expand All @@ -89,15 +81,24 @@ var forcedTime = (time.Time{}).Add(1)
func (u *udpPort) forceResponse() (added bool) {
if !u.IsPendingHandling() {
added = true
u.packets[0].Rx = forcedTime
u.packets[0].flagPendingNoPacket()
}
return added
}

func (u *UDPPacket) HasPacket() bool {
return u.Rx != forcedTime && !u.Rx.IsZero() // TODO simplify this to just IsZero
type UDPPacket struct {
Rx time.Time
Eth eth.EthernetHeader
IP eth.IPv4Header
UDP eth.UDPHeader
payload [_MTU - eth.SizeEthernetHeader - eth.SizeIPv4Header - eth.SizeUDPHeader]byte
}

func (u *UDPPacket) HasPacket() bool { return u.Rx != forcedTime && !u.Rx.IsZero() }
func (u *UDPPacket) pendingHandling() bool { return !u.Rx.IsZero() }
func (u *UDPPacket) invalidate() { u.Rx = time.Time{} }
func (u *UDPPacket) flagPendingNoPacket() { u.Rx = forcedTime }

func (p *UDPPacket) PutHeaders(b []byte) {
if len(b) < eth.SizeEthernetHeader+eth.SizeIPv4Header+eth.SizeUDPHeader {
panic("short UDPPacket buffer")
Expand Down
34 changes: 22 additions & 12 deletions stack/portstack.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ var ErrFlagPending = io.ErrNoProgress
// - Users can safely use pkt even if pkt.HasPacket() returns false.
//
// - If the handler returns an error that is not ErrFlagPending then the port
// is immediately closed and written data is discarded.
// is immediately closed.
//
// - [io.EOF] and ErrFlagPending: When returned by handler data written is not discarded.
// This means that the handler can write data and close port in same operation returning non-zero `n` and EOF.
//
// - ErrFlagPending: When returned by the handler then the port is flagged as
// pending and the written data is handled normally if there is any. If no data is written
Expand All @@ -74,27 +77,29 @@ type PortStack struct {
lastRx time.Time
lastRxSuccess time.Time
lastTx time.Time
mac [6]byte
// Set ip to non-nil to ignore packets not meant for us.
ip [4]byte
portsUDP []udpPort
portsTCP []tcpPort
glob func([]byte)
glob func([]byte)
logger *slog.Logger
portsUDP []udpPort
portsTCP []tcpPort

pendingUDPv4 uint32
pendingTCPv4 uint32
droppedPackets uint32
processedPackets uint32
// droppedPackets counts amount of packets corresponding to TCP/UDP ports
// that have been dropped due to the port requiring handling before admitting more packets.
droppedPackets uint32
// pending ARP reply that must be sent out.
pendingARPresponse eth.ARPv4Header
ARPresult eth.ARPv4Header
logger *slog.Logger
mac [6]byte
ip [4]byte
}

// Common errors.
var (
ErrDroppedPacket = errors.New("dropped packet")
errPacketExceedsMTU = errors.New("packet exceeds MTU")
errNotIPv4 = errors.New("require IPv4")
// errNotIPv4 = errors.New("require IPv4")
errPacketSmol = errors.New("packet too small")
errTooShortTCPOrUDP = errors.New("packet too short to be TCP/UDP")
errZeroPort = errors.New("zero port in TCP/UDP")
Expand All @@ -113,7 +118,7 @@ var (
func (ps *PortStack) Addr() netip.Addr { return netip.AddrFrom4(ps.ip) }
func (ps *PortStack) SetAddr(addr netip.Addr) {
if !addr.Is4() {
panic("SetAddr only supports IPv4, or argument is not an IP address")
panic("SetAddr only supports IPv4, or argument not initialized")
}
ps.ip = addr.As4()
}
Expand Down Expand Up @@ -358,7 +363,12 @@ func (ps *PortStack) HandleEth(dst []byte) (n int, err error) {
}
if err != nil {
sock.Close()
n = 0
if err == io.EOF {
// Special case: If error is EOF we don't return it to caller but we do write the packet if any.
err = nil
} else {
n = 0 // Clear n on unknown error and return error up the call stack.
}
}
return n, false, err
}
Expand Down
18 changes: 17 additions & 1 deletion stack/socket_tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ func (t *TCPSocket) State() seqs.State {
}

func (t *TCPSocket) Send(b []byte) error {
if t.abortErr != nil {
return t.abortErr
}
if t.scb.State() != seqs.StateEstablished {
return errors.New("connection not established")
}
Expand All @@ -75,6 +78,9 @@ func (t *TCPSocket) Send(b []byte) error {
}

func (t *TCPSocket) Recv(b []byte) (int, error) {
if t.abortErr != nil {
return 0, t.abortErr
}
if t.closing {
return 0, io.EOF
}
Expand Down Expand Up @@ -125,6 +131,9 @@ func ListenTCP(stack *PortStack, port uint16, iss seqs.Value, window seqs.Size)
}

func (t *TCPSocket) Close() error {
if t.abortErr != nil {
return t.abortErr
}
toSend := t.tx.Buffered()
if toSend == 0 {
err := t.scb.Close()
Expand All @@ -138,6 +147,9 @@ func (t *TCPSocket) Close() error {
}

func (t *TCPSocket) handleMain(response []byte, pkt *TCPPacket) (n int, err error) {
if t.abortErr != nil {
return 0, t.abortErr // Force close of socket.
}
defer func() {
if err != nil && t.abortErr == nil && err != ErrFlagPending {
err = nil // Only close socket if socket is aborted.
Expand Down Expand Up @@ -226,6 +238,8 @@ func (t *TCPSocket) handleSend(response []byte, pkt *TCPPacket) (n int, err erro

if t.scb.HasPending() {
err = ErrFlagPending // Flag to PortStack that we have pending data to send.
} else if t.scb.State() == seqs.StateClosed {
err = io.EOF
}
return sizeTCPNoOptions + n, err
}
Expand Down Expand Up @@ -262,6 +276,8 @@ func (t *TCPSocket) close() {
t.lastTx = time.Time{}
t.lastRx = time.Time{}
t.closing = false
// t.stack.CloseTCP(t.localPort)
t.abortErr = io.ErrClosedPipe
}

func (t *TCPSocket) synsentSegment() seqs.Segment {
Expand All @@ -274,7 +290,7 @@ func (t *TCPSocket) synsentSegment() seqs.Segment {
}

func (t *TCPSocket) abort(err error) error {
t.abortErr = err
t.close()
t.abortErr = err
return err
}

0 comments on commit 19e1a1a

Please sign in to comment.