Skip to content
This repository has been archived by the owner on Apr 3, 2021. It is now read-only.

Commit

Permalink
turn TCPConn into a net.Conn
Browse files Browse the repository at this point in the history
  • Loading branch information
eycorsican committed May 10, 2019
1 parent 1d49ed1 commit 643323f
Show file tree
Hide file tree
Showing 9 changed files with 163 additions and 388 deletions.
11 changes: 11 additions & 0 deletions core/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package core

import (
"net"
"time"
)

// TCPConn abstracts a TCP connection comming from TUN. This connection
Expand All @@ -16,6 +17,12 @@ type TCPConn interface {
// Receive receives data from TUN.
Receive(data []byte) error

// Read reads data comming from TUN, note that it reads from an
// underlying pipe that the writer writes in the lwip thread,
// write op blocks until previous written data is consumed, one
// should read out all data as soon as possible.
Read(data []byte) (int, error)

// Write writes data to TUN.
Write(data []byte) (int, error)

Expand All @@ -36,6 +43,10 @@ type TCPConn interface {

// Poll will be periodically called by timers.
Poll() error

SetDeadline(t time.Time) error
SetReadDeadline(t time.Time) error
SetWriteDeadline(t time.Time) error
}

// TCPConn abstracts a UDP connection comming from TUN. This connection
Expand Down
13 changes: 2 additions & 11 deletions core/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,8 @@ import (

// TCPConnHandler handles TCP connections comming from TUN.
type TCPConnHandler interface {
// Connect connects the proxy server.
Connect(conn TCPConn, target net.Addr) error

// DidReceive will be called when data arrives from TUN.
DidReceive(conn TCPConn, data []byte) error

// DidClose will be called when the connection has been closed.
DidClose(conn TCPConn)

// LocalDidClose will be called when local client has close the connection.
LocalDidClose(conn TCPConn)
// Handle handles the conn for target.
Handle(conn net.Conn, target net.Addr) error
}

// UDPConnHandler handles UDP connections comming from TUN.
Expand Down
83 changes: 58 additions & 25 deletions core/tcp_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ import "C"
import (
"errors"
"fmt"
"io"
"math/rand"
"net"
"sync"
"time"
"unsafe"
)

Expand All @@ -21,6 +23,7 @@ const (
tcpConnecting
tcpConnected
tcpClosing
tcpClosed
tcpLocalClosed
tcpAborting
tcpErrored
Expand All @@ -29,14 +32,18 @@ const (
type tcpConn struct {
sync.Mutex

pcb *C.struct_tcp_pcb
handler TCPConnHandler
remoteAddr net.Addr
localAddr net.Addr
connKeyArg unsafe.Pointer
connKey uint32
canWrite *sync.Cond // Condition variable to implement TCP backpressure.
state tcpConnState
pcb *C.struct_tcp_pcb
handler TCPConnHandler
remoteAddr net.Addr
localAddr net.Addr
connKeyArg unsafe.Pointer
connKey uint32
canWrite *sync.Cond // Condition variable to implement TCP backpressure.
state tcpConnState
sndPipeReader *io.PipeReader
sndPipeWriter *io.PipeWriter
closeOnce sync.Once
closeErr error
}

func newTCPConn(pcb *C.struct_tcp_pcb, handler TCPConnHandler) (TCPConn, error) {
Expand All @@ -53,15 +60,18 @@ func newTCPConn(pcb *C.struct_tcp_pcb, handler TCPConnHandler) (TCPConn, error)
setTCPErrCallback(pcb)
setTCPPollCallback(pcb, C.u8_t(TCP_POLL_INTERVAL))

pipeReader, pipeWriter := io.Pipe()
conn := &tcpConn{
pcb: pcb,
handler: handler,
localAddr: ParseTCPAddr(ipAddrNTOA(pcb.remote_ip), uint16(pcb.remote_port)),
remoteAddr: ParseTCPAddr(ipAddrNTOA(pcb.local_ip), uint16(pcb.local_port)),
connKeyArg: connKeyArg,
connKey: connKey,
canWrite: sync.NewCond(&sync.Mutex{}),
state: tcpNewConn,
pcb: pcb,
handler: handler,
localAddr: ParseTCPAddr(ipAddrNTOA(pcb.remote_ip), uint16(pcb.remote_port)),
remoteAddr: ParseTCPAddr(ipAddrNTOA(pcb.local_ip), uint16(pcb.local_port)),
connKeyArg: connKeyArg,
connKey: connKey,
canWrite: sync.NewCond(&sync.Mutex{}),
state: tcpNewConn,
sndPipeReader: pipeReader,
sndPipeWriter: pipeWriter,
}

// Associate conn with key and save to the global map.
Expand All @@ -73,7 +83,7 @@ func newTCPConn(pcb *C.struct_tcp_pcb, handler TCPConnHandler) (TCPConn, error)
conn.state = tcpConnecting
conn.Unlock()
go func() {
err := handler.Connect(conn, conn.RemoteAddr())
err := handler.Handle(TCPConn(conn), conn.RemoteAddr())
if err != nil {
conn.Abort()
} else {
Expand All @@ -94,6 +104,16 @@ func (conn *tcpConn) LocalAddr() net.Addr {
return conn.localAddr
}

func (conn *tcpConn) SetDeadline(t time.Time) error {
return nil
}
func (conn *tcpConn) SetReadDeadline(t time.Time) error {
return nil
}
func (conn *tcpConn) SetWriteDeadline(t time.Time) error {
return nil
}

func (conn *tcpConn) receiveCheck() error {
conn.Lock()
defer conn.Unlock()
Expand Down Expand Up @@ -121,16 +141,18 @@ func (conn *tcpConn) Receive(data []byte) error {
if err := conn.receiveCheck(); err != nil {
return err
}
err := conn.handler.DidReceive(conn, data)
n, err := conn.sndPipeWriter.Write(data)
if err != nil {
conn.abortInternal()
conn.canWrite.Broadcast()
return NewLWIPError(LWIP_ERR_ABRT)
return NewLWIPError(LWIP_ERR_CONN)
}
C.tcp_recved(conn.pcb, C.u16_t(len(data)))
C.tcp_recved(conn.pcb, C.u16_t(n))
return NewLWIPError(LWIP_ERR_OK)
}

func (conn *tcpConn) Read(data []byte) (int, error) {
return conn.sndPipeReader.Read(data)
}

// writeInternal enqueues data to snd_buf, and treats ERR_MEM returned by tcp_write not an error,
// but instead tells the caller that data is not successfully enqueued, and should try
// again another time. By calling this function, the lwIP thread is assumed to be already
Expand All @@ -155,6 +177,10 @@ func (conn *tcpConn) writeCheck() error {
return fmt.Errorf("connection %v->%v encountered a fatal error", conn.LocalAddr(), conn.RemoteAddr())
case tcpAborting:
return fmt.Errorf("connection %v->%v is aborting", conn.LocalAddr(), conn.RemoteAddr())
case tcpClosing:
return fmt.Errorf("connection %v->%v is closing", conn.LocalAddr(), conn.RemoteAddr())
case tcpClosed:
return fmt.Errorf("connection %v->%v was closed", conn.LocalAddr(), conn.RemoteAddr())
case tcpLocalClosed:
return fmt.Errorf("connection %v->%v was closed by local", conn.LocalAddr(), conn.RemoteAddr())
case tcpConnected:
Expand Down Expand Up @@ -229,6 +255,13 @@ func (conn *tcpConn) CheckState() error {
}

func (conn *tcpConn) Close() error {
conn.closeOnce.Do(conn.close)
return conn.closeErr
}

func (conn *tcpConn) close() {
conn.sndPipeWriter.Close()

lwipMutex.Lock()
C.tcp_shutdown(conn.pcb, 0, 1) // Close the TX side ASAP.
lwipMutex.Unlock()
Expand All @@ -239,7 +272,7 @@ func (conn *tcpConn) Close() error {
// Close maybe called outside of lwIP thread, we should not call tcp_close() in this
// function, instead just make a flag to indicate we are closing the connection.
conn.state = tcpClosing
return nil
conn.closeErr = nil
}

func (conn *tcpConn) setLocalClosed() error {
Expand Down Expand Up @@ -295,12 +328,10 @@ func (conn *tcpConn) Abort() {
// The corresponding pcb is already freed when this callback is called
func (conn *tcpConn) Err(err error) {
conn.release()
conn.handler.DidClose(conn)
conn.setErrored()
}

func (conn *tcpConn) LocalDidClose() error {
conn.handler.LocalDidClose(conn)
conn.setLocalClosed()
return conn.CheckState()
}
Expand All @@ -310,6 +341,8 @@ func (conn *tcpConn) release() {
freeConnKeyArg(conn.connKeyArg)
tcpConns.Delete(conn.connKey)
}
conn.sndPipeReader.Close()
conn.state = tcpClosed
}

func (conn *tcpConn) Poll() error {
Expand Down
71 changes: 18 additions & 53 deletions proxy/direct/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,70 +12,35 @@ import (
"github.com/eycorsican/go-tun2socks/core"
)

type tcpHandler struct {
sync.Mutex
conns map[core.TCPConn]net.Conn
}
type tcpHandler struct{}

func NewTCPHandler() core.TCPConnHandler {
return &tcpHandler{
conns: make(map[core.TCPConn]net.Conn, 16),
}
return &tcpHandler{}
}

func (h *tcpHandler) fetchInput(conn core.TCPConn, input io.Reader) {
_, err := io.Copy(conn, input)
if err != nil {
h.Close(conn)
func (h *tcpHandler) handleInput(conn net.Conn, input io.ReadCloser) {
defer func() {
conn.Close()
}
input.Close()
}()
io.Copy(conn, input)
}

func (h *tcpHandler) Connect(conn core.TCPConn, target net.Addr) error {
func (h *tcpHandler) handleOutput(conn net.Conn, output io.WriteCloser) {
defer func() {
conn.Close()
output.Close()
}()
io.Copy(output, conn)
}

func (h *tcpHandler) Handle(conn net.Conn, target net.Addr) error {
c, err := net.Dial("tcp", target.String())
if err != nil {
return err
}
h.Lock()
h.conns[conn] = c
h.Unlock()
c.SetReadDeadline(time.Time{})
go h.fetchInput(conn, c)
go h.handleInput(conn, c)
go h.handleOutput(conn, c)
log.Infof("new proxy connection for target: %s:%s", target.Network(), target.String())
return nil
}

func (h *tcpHandler) DidReceive(conn core.TCPConn, data []byte) error {
h.Lock()
c, ok := h.conns[conn]
h.Unlock()
if ok {
_, err := c.Write(data)
if err != nil {
h.Close(conn)
return errors.New(fmt.Sprintf("write remote failed: %v", err))
}
return nil
} else {
h.Close(conn)
return errors.New(fmt.Sprintf("proxy connection %v->%v does not exists", conn.LocalAddr(), conn.RemoteAddr()))
}
}

func (h *tcpHandler) DidClose(conn core.TCPConn) {
h.Close(conn)
}

func (h *tcpHandler) LocalDidClose(conn core.TCPConn) {
h.Close(conn)
}

func (h *tcpHandler) Close(conn core.TCPConn) {
h.Lock()
defer h.Unlock()

if c, found := h.conns[conn]; found {
c.Close()
}
delete(h.conns, conn)
}
48 changes: 7 additions & 41 deletions proxy/echo/tcp.go
Original file line number Diff line number Diff line change
@@ -1,60 +1,26 @@
package echo

import (
"io"
"net"

"github.com/eycorsican/go-tun2socks/core"
)

var bufSize = 10 * 1024

type connEntry struct {
data []byte
conn core.TCPConn
}

// An echo proxy, do nothing but echo back data to the sender, the handler was
// created for testing purposes, it may causes issues when more than one clients
// are connecting the handler simultaneously.
type tcpHandler struct {
buf chan *connEntry
}
type tcpHandler struct{}

func NewTCPHandler() core.TCPConnHandler {
handler := &tcpHandler{
buf: make(chan *connEntry, bufSize),
}
go handler.echoBack()
return handler
return &tcpHandler{}
}

func (h *tcpHandler) echoBack() {
for {
e := <-h.buf
_, err := e.conn.Write(e.data)
if err != nil {
e.conn.Close()
}
}
func (h *tcpHandler) echoBack(conn net.Conn) {
io.Copy(conn, conn)
}

func (h *tcpHandler) Connect(conn core.TCPConn, target net.Addr) error {
func (h *tcpHandler) Handle(conn net.Conn, target net.Addr) error {
go h.echoBack(conn)
return nil
}

func (h *tcpHandler) DidReceive(conn core.TCPConn, data []byte) error {
payload := append([]byte(nil), data...)
// This function runs in lwIP thread, we can't block, so discarding data if
// buf if full.
select {
case h.buf <- &connEntry{data: payload, conn: conn}:
default:
}
return nil
}

func (h *tcpHandler) DidClose(conn core.TCPConn) {
}

func (h *tcpHandler) LocalDidClose(conn core.TCPConn) {
}
Loading

0 comments on commit 643323f

Please sign in to comment.