diff --git a/stack/port_tcp.go b/stack/port_tcp.go index 9b138b4..cb6e902 100644 --- a/stack/port_tcp.go +++ b/stack/port_tcp.go @@ -24,31 +24,28 @@ type tcpPort struct { packets [1]TCPPacket } -func (p tcpPort) Port() uint16 { return p.port } +func (port tcpPort) Port() uint16 { return port.port } // NeedsHandling returns true if the socket needs handling before it can // admit more pending packets. -func (u *tcpPort) NeedsHandling() bool { - // As of now socket has space for 1 packet so if packet is pending, queue is full. - // Compile time check to ensure this is fulfilled: - _ = u.packets[1-len(u.packets)] - return u.IsPendingHandling() +func (port *tcpPort) NeedsHandling() bool { + return port.freePacket() == nil } // IsPendingHandling returns true if there are packet(s) pending handling. -func (u *tcpPort) IsPendingHandling() bool { - return u.port != 0 && u.packets[0].pendingHandling() +func (port *tcpPort) IsPendingHandling() bool { + return port.port != 0 && port.nextPacket().pendingHandling() } // HandleEth writes the socket's response into dst to be sent over an ethernet interface. // HandleEth can return 0 bytes written and a nil error to indicate no action must be taken. -func (u *tcpPort) HandleEth(dst []byte) (n int, err error) { - if u.handler == nil { - panic("nil tcp handler on port " + strconv.Itoa(int(u.port))) +func (port *tcpPort) HandleEth(dst []byte) (n int, err error) { + if port.handler == nil { + panic("nil tcp handler on port " + strconv.Itoa(int(port.port))) } - packet := &u.packets[0] - n, err = u.handler(dst, &u.packets[0]) + packet := port.nextPacket() + n, err = port.handler(dst, packet) if err == ErrFlagPending { packet.flagPendingNoPacket() // Mark socket as needing handling but packet having no data. } else { @@ -57,38 +54,60 @@ func (u *tcpPort) HandleEth(dst []byte) (n int, err error) { return n, err } +// nextPacket returns the next packet that is pending handling or the first packet if none are pending. +func (port *tcpPort) nextPacket() *TCPPacket { + for i := range port.packets { + if port.packets[i].pendingHandling() { + return &port.packets[i] + } + } + return &port.packets[0] +} + +// freePacket returns the first packet that is not pending handling or nil if all packets are pending. +func (port *tcpPort) freePacket() *TCPPacket { + for i := range port.packets { + if !port.packets[i].pendingHandling() { + return &port.packets[i] + } + } + return nil +} + // Open sets the UDP handler and opens the port. -func (u *tcpPort) Open(port uint16, handler tcphandler) { - if port == 0 || handler == nil { - panic("invalid port or nil handler" + strconv.Itoa(int(u.port))) +func (port *tcpPort) Open(portNum uint16, handler tcphandler) { + if portNum == 0 || handler == nil { + panic("invalid port or nil handler" + strconv.Itoa(int(port.port))) } - u.handler = handler - u.port = port - for i := range u.packets { - u.packets[i].invalidate() + port.handler = handler + port.port = portNum + for i := range port.packets { + port.packets[i].invalidate() } } -func (s *tcpPort) pending() (p uint32) { - for i := range s.packets { - if s.packets[i].pendingHandling() { +func (port *tcpPort) pending() (p uint32) { + for i := range port.packets { + if port.packets[i].pendingHandling() { p++ } } return p } -func (u *tcpPort) Close() { - u.handler = nil - u.port = 0 // Port 0 flags the port is inactive. +func (port *tcpPort) Close() { + port.handler = nil + port.port = 0 // Port 0 flags the port is inactive. } -func (u *tcpPort) forceResponse() (added bool) { - if !u.IsPendingHandling() { - added = true - u.packets[0].flagPendingNoPacket() +func (port *tcpPort) forceResponse() (added bool) { + for i := range port.packets { + if !port.packets[i].pendingHandling() { + port.packets[i].flagPendingNoPacket() + return true + } } - return added + return false } const tcpMTU = _MTU - eth.SizeEthernetHeader - eth.SizeIPv4Header - eth.SizeTCPHeader @@ -102,31 +121,31 @@ type TCPPacket struct { data [tcpMTU]byte } -func (p *TCPPacket) String() string { - return "TCP Packet: " + p.Eth.String() + " " + p.IP.String() + " " + p.TCP.String() + " payload:" + strconv.Quote(string(p.Payload())) +func (pkt *TCPPacket) String() string { + return "TCP Packet: " + pkt.Eth.String() + " " + pkt.IP.String() + " " + pkt.TCP.String() + " payload:" + strconv.Quote(string(pkt.Payload())) } -func (u *TCPPacket) HasPacket() bool { return u.Rx != forcedTime && !u.Rx.IsZero() } -func (u *TCPPacket) pendingHandling() bool { return !u.Rx.IsZero() } -func (u *TCPPacket) invalidate() { u.Rx = time.Time{} } -func (u *TCPPacket) flagPendingNoPacket() { u.Rx = forcedTime } +func (pkt *TCPPacket) HasPacket() bool { return pkt.Rx != forcedTime && !pkt.Rx.IsZero() } +func (pkt *TCPPacket) pendingHandling() bool { return !pkt.Rx.IsZero() } +func (pkt *TCPPacket) invalidate() { pkt.Rx = time.Time{} } +func (pkt *TCPPacket) flagPendingNoPacket() { pkt.Rx = forcedTime } // PutHeaders puts 54 bytes including the Ethernet, IPv4 and TCP headers into b. // b must be at least 54 bytes in length or else PutHeaders panics. No options are marshalled. -func (p *TCPPacket) PutHeaders(b []byte) { +func (pkt *TCPPacket) PutHeaders(b []byte) { const minSize = eth.SizeEthernetHeader + eth.SizeIPv4Header + eth.SizeTCPHeader if len(b) < minSize { panic("short tcpPacket buffer") } - if p.IP.IHL() != 5 || p.TCP.Offset() != 5 { + if pkt.IP.IHL() != 5 || pkt.TCP.Offset() != 5 { panic("TCPPacket.PutHeaders expects no IP or TCP options") } - p.Eth.Put(b) - p.IP.Put(b[eth.SizeEthernetHeader:]) - p.TCP.Put(b[eth.SizeEthernetHeader+eth.SizeIPv4Header:]) + pkt.Eth.Put(b) + pkt.IP.Put(b[eth.SizeEthernetHeader:]) + pkt.TCP.Put(b[eth.SizeEthernetHeader+eth.SizeIPv4Header:]) } -func (p *TCPPacket) PutHeadersWithOptions(b []byte) error { +func (pkt *TCPPacket) PutHeadersWithOptions(b []byte) error { const minSize = eth.SizeEthernetHeader + eth.SizeIPv4Header + eth.SizeTCPHeader if len(b) < minSize { panic("short tcpPacket buffer") @@ -136,48 +155,48 @@ func (p *TCPPacket) PutHeadersWithOptions(b []byte) error { // Payload returns the TCP payload. If TCP or IPv4 header data is incorrect/bad it returns nil. // If the response is "forced" then payload will be nil. -func (p *TCPPacket) Payload() []byte { - if !p.HasPacket() { +func (pkt *TCPPacket) Payload() []byte { + if !pkt.HasPacket() { return nil } - payloadStart, payloadEnd, _ := p.dataPtrs() + payloadStart, payloadEnd, _ := pkt.dataPtrs() if payloadStart < 0 { return nil // Bad header value } - return p.data[payloadStart:payloadEnd] + return pkt.data[payloadStart:payloadEnd] } // Options returns the TCP options in the packet. -func (p *TCPPacket) TCPOptions() []byte { - if !p.HasPacket() { +func (pkt *TCPPacket) TCPOptions() []byte { + if !pkt.HasPacket() { return nil } - payloadStart, _, tcpOptStart := p.dataPtrs() + payloadStart, _, tcpOptStart := pkt.dataPtrs() if payloadStart < 0 { return nil // Bad header value } - return p.data[tcpOptStart:payloadStart] + return pkt.data[tcpOptStart:payloadStart] } // Options returns the TCP options in the packet. -func (p *TCPPacket) IPOptions() []byte { - if !p.HasPacket() { +func (pkt *TCPPacket) IPOptions() []byte { + if !pkt.HasPacket() { return nil } - _, _, tcpOpts := p.dataPtrs() + _, _, tcpOpts := pkt.dataPtrs() if tcpOpts < 0 { return nil // Bad header value } - return p.data[:tcpOpts] + return pkt.data[:tcpOpts] } //go:inline -func (p *TCPPacket) dataPtrs() (payloadStart, payloadEnd, tcpOptStart int) { - tcpOptStart = int(4*p.IP.IHL()) - eth.SizeIPv4Header - payloadStart = tcpOptStart + int(p.TCP.OffsetInBytes()) - eth.SizeTCPHeader - payloadEnd = int(p.IP.TotalLength) - tcpOptStart - eth.SizeTCPHeader - eth.SizeIPv4Header +func (pkt *TCPPacket) dataPtrs() (payloadStart, payloadEnd, tcpOptStart int) { + tcpOptStart = int(4*pkt.IP.IHL()) - eth.SizeIPv4Header + payloadStart = tcpOptStart + int(pkt.TCP.OffsetInBytes()) - eth.SizeTCPHeader + payloadEnd = int(pkt.IP.TotalLength) - tcpOptStart - eth.SizeTCPHeader - eth.SizeIPv4Header if payloadStart < 0 || payloadEnd < 0 || tcpOptStart < 0 || payloadStart > payloadEnd || - payloadEnd > len(p.data) || tcpOptStart > payloadStart { + payloadEnd > len(pkt.data) || tcpOptStart > payloadStart { return -1, -1, -1 } return payloadStart, payloadEnd, tcpOptStart @@ -210,6 +229,7 @@ func (pkt *TCPPacket) CalculateHeaders(seg seqs.Segment, payload []byte) { pkt.IP.Flags = 0 // TCP frame. const offset = 5 + pkt.TCP = eth.TCPHeader{ SourcePort: pkt.TCP.SourcePort, DestinationPort: pkt.TCP.DestinationPort, diff --git a/stack/port_udp.go b/stack/port_udp.go index 0afe746..b1bd3eb 100644 --- a/stack/port_udp.go +++ b/stack/port_udp.go @@ -16,31 +16,28 @@ type udpPort struct { packets [1]UDPPacket } -func (u udpPort) Port() uint16 { return u.port } +func (port udpPort) Port() uint16 { return port.port } // NeedsHandling returns true if the socket needs handling before it can // admit more pending packets. -func (u *udpPort) NeedsHandling() bool { - // As of now socket has space for 1 packet so if packet is pending, queue is full. - // Compile time check to ensure this is fulfilled: - _ = u.packets[1-len(u.packets)] - return u.IsPendingHandling() +func (port *udpPort) NeedsHandling() bool { + return port.freePacket() == nil } // IsPendingHandling returns true if there are packet(s) pending handling. -func (u *udpPort) IsPendingHandling() bool { - return u.port != 0 && u.packets[0].pendingHandling() +func (port *udpPort) IsPendingHandling() bool { + return port.port != 0 && port.nextPacket().pendingHandling() } // HandleEth writes the socket's response into dst to be sent over an ethernet interface. // HandleEth can return 0 bytes written and a nil error to indicate no action must be taken. -func (u *udpPort) HandleEth(dst []byte) (int, error) { - if u.handler == nil { - panic("nil udp handler on port " + strconv.Itoa(int(u.port))) +func (port *udpPort) HandleEth(dst []byte) (int, error) { + if port.handler == nil { + panic("nil udp handler on port " + strconv.Itoa(int(port.port))) } - packet := &u.packets[0] + packet := port.nextPacket() - n, err := u.handler(dst, &u.packets[0]) + n, err := port.handler(dst, packet) if err == ErrFlagPending { packet.flagPendingNoPacket() // Mark socket as needing handling but packet having no data. } else { @@ -50,40 +47,62 @@ func (u *udpPort) HandleEth(dst []byte) (int, error) { } // Open sets the UDP handler and opens the port. -func (u *udpPort) Open(port uint16, h udphandler) { - if port == 0 || h == nil { - panic("invalid port or nil handler" + strconv.Itoa(int(u.port))) +func (port *udpPort) Open(portNum uint16, h udphandler) { + if portNum == 0 || h == nil { + panic("invalid port or nil handler" + strconv.Itoa(int(port.port))) } - u.handler = h - u.port = port + port.handler = h + port.port = portNum } -func (s *udpPort) pending() (p int) { - for i := range s.packets { - if s.packets[i].pendingHandling() { +func (port *udpPort) pending() (p int) { + for i := range port.packets { + if port.packets[i].pendingHandling() { p++ } } return p } -func (u *udpPort) Close() { - u.port = 0 // Port 0 flags the port is inactive. - for i := range u.packets { - u.packets[i].invalidate() +func (port *udpPort) Close() { + port.port = 0 // Port 0 flags the port is inactive. + for i := range port.packets { + port.packets[i].invalidate() } } +// nextPacket returns the next packet that is pending handling or the first packet if none are pending. +func (port *udpPort) nextPacket() *UDPPacket { + for i := range port.packets { + if port.packets[i].pendingHandling() { + return &port.packets[i] + } + } + return &port.packets[0] +} + +// freePacket returns the first packet that is not pending handling or nil if all packets are pending. +func (port *udpPort) freePacket() *UDPPacket { + for i := range port.packets { + if !port.packets[i].pendingHandling() { + return &port.packets[i] + } + } + return nil +} + // UDP socket can be forced to respond even if no packet has been received // by flagging the packet's Rx time with non-zero value. var forcedTime = (time.Time{}).Add(1) -func (u *udpPort) forceResponse() (added bool) { - if !u.IsPendingHandling() { - added = true - u.packets[0].flagPendingNoPacket() +func (port *udpPort) forceResponse() (added bool) { + for i := range port.packets { + if !port.packets[i].pendingHandling() { + port.packets[i].flagPendingNoPacket() + return true + } } - return added + return false } type UDPPacket struct { @@ -94,33 +113,33 @@ type UDPPacket struct { payload [_MTU - eth.SizeEthernetHeader - eth.SizeIPv4Header - eth.SizeUDPHeader]byte } -func (u *UDPPacket) HasPacket() bool { return u.Rx != forcedTime && !u.Rx.IsZero() } -func (u *UDPPacket) pendingHandling() bool { return !u.Rx.IsZero() } -func (u *UDPPacket) invalidate() { u.Rx = time.Time{} } -func (u *UDPPacket) flagPendingNoPacket() { u.Rx = forcedTime } +func (pkt *UDPPacket) HasPacket() bool { return pkt.Rx != forcedTime && !pkt.Rx.IsZero() } +func (pkt *UDPPacket) pendingHandling() bool { return !pkt.Rx.IsZero() } +func (pkt *UDPPacket) invalidate() { pkt.Rx = time.Time{} } +func (pkt *UDPPacket) flagPendingNoPacket() { pkt.Rx = forcedTime } -func (p *UDPPacket) PutHeaders(b []byte) { +func (pkt *UDPPacket) PutHeaders(b []byte) { if len(b) < eth.SizeEthernetHeader+eth.SizeIPv4Header+eth.SizeUDPHeader { panic("short UDPPacket buffer") } - if p.IP.IHL() != 5 { + if pkt.IP.IHL() != 5 { panic("UDPPacket.PutHeaders expects no IP options") } - p.Eth.Put(b) - p.IP.Put(b[eth.SizeEthernetHeader:]) - p.UDP.Put(b[eth.SizeEthernetHeader+eth.SizeIPv4Header:]) + pkt.Eth.Put(b) + pkt.IP.Put(b[eth.SizeEthernetHeader:]) + pkt.UDP.Put(b[eth.SizeEthernetHeader+eth.SizeIPv4Header:]) } // Payload returns the UDP payload. If UDP or IPv4 header data is incorrect/bad it returns nil. // If the response is "forced" then payload will be nil. -func (p *UDPPacket) Payload() []byte { - if !p.HasPacket() { +func (pkt *UDPPacket) Payload() []byte { + if !pkt.HasPacket() { return nil } - ipLen := int(p.IP.TotalLength) - int(p.IP.IHL()*4) - eth.SizeUDPHeader // Total length(including header) - header length = payload length - uLen := int(p.UDP.Length) - eth.SizeUDPHeader - if ipLen != uLen || uLen > len(p.payload) { + ipLen := int(pkt.IP.TotalLength) - int(pkt.IP.IHL()*4) - eth.SizeUDPHeader // Total length(including header) - header length = payload length + uLen := int(pkt.UDP.Length) - eth.SizeUDPHeader + if ipLen != uLen || uLen > len(pkt.payload) { return nil // Mismatching IP and UDP data or bad length. } - return p.payload[:uLen] + return pkt.payload[:uLen] }