diff --git a/stack/socket_tcp.go b/stack/socket_tcp.go index 75843f4..c72f9ac 100644 --- a/stack/socket_tcp.go +++ b/stack/socket_tcp.go @@ -1,6 +1,7 @@ package stack import ( + "cmp" "errors" "io" "net/netip" @@ -19,6 +20,7 @@ type tcp struct { remote netip.AddrPort remoteMAC [6]byte tx ring + rx ring } func (t *tcp) State() seqs.State { @@ -28,7 +30,7 @@ func (t *tcp) State() seqs.State { func (t *tcp) Send(b []byte) error { if t.tx.buf == nil { t.tx = ring{ - buf: make([]byte, 2048), + buf: make([]byte, max(2048, len(b))), } } _, err := t.tx.Write(b) @@ -132,12 +134,30 @@ func (t *tcp) handleRecv(response []byte, pkt *TCPPacket) (n int, err error) { } func (t *tcp) handleUser(response []byte, pkt *TCPPacket) (n int, err error) { - return 0, nil -} + available := t.tx.Buffered() + if available == 0 { + return 0, nil // No data to send. + } + seg, ok := t.scb.PendingSegment(available) + if !ok { + return 0, errors.New("possible segment not found") // No pending control segment. Yield to handleUser. + } + err = t.scb.Send(seg) + if err != nil { + return 0, err + } + t.setSrcDest(pkt) + pkt.PutHeaders(response) + payloadPlace := response[54:] + n, err = t.tx.Read(payloadPlace[:seg.DATALEN]) + if err != nil || n != int(seg.DATALEN) { + panic("bug in handleUser") // This is a bug in ring buffer or a race condition. + } -func (t *tcp) handleInitSyn(response []byte, pkt *TCPPacket) (n int, err error) { - // Uninitialized TCB, we start the handshake. + return 54 + n, err +} +func (t *tcp) setSrcDest(pkt *TCPPacket) { copy(pkt.Eth.Source[:], t.stack.MAC) pkt.IP.Source = t.stack.IP.As4() pkt.TCP.SourcePort = t.localPort @@ -145,7 +165,11 @@ func (t *tcp) handleInitSyn(response []byte, pkt *TCPPacket) (n int, err error) pkt.IP.Destination = t.remote.Addr().As4() pkt.TCP.DestinationPort = t.remote.Port() pkt.Eth.Destination = t.remoteMAC +} +func (t *tcp) handleInitSyn(response []byte, pkt *TCPPacket) (n int, err error) { + // Uninitialized TCB, we start the handshake. + t.setSrcDest(pkt) pkt.CalculateHeaders(t.synsentSegment(), nil) pkt.PutHeaders(response) return 54, nil @@ -260,3 +284,17 @@ func (r *ring) onReadEnd() { r.end = 0 } } + +func max[T cmp.Ordered](a, b T) T { + if a > b { + return a + } + return b +} + +func min[T cmp.Ordered](a, b T) T { + if a < b { + return a + } + return b +}