diff --git a/network/lwip2transport/udp.go b/network/lwip2transport/udp.go index 3666d18f..40a48923 100644 --- a/network/lwip2transport/udp.go +++ b/network/lwip2transport/udp.go @@ -17,6 +17,7 @@ package lwip2transport import ( "net" "sync" + "sync/atomic" "github.com/Jigsaw-Code/outline-internal-sdk/network" lwip "github.com/eycorsican/go-tun2socks/core" @@ -69,7 +70,10 @@ func (h *udpHandler) ReceiveTo(tunConn lwip.UDPConn, data []byte, destAddr *net. // newSession creates a new PacketRequestSender related to conn. The caller needs to put the new PacketRequestSender // to the h.senders map. func (h *udpHandler) newSession(conn lwip.UDPConn) (network.PacketRequestSender, error) { - respWriter := &udpConnResponseWriter{conn, h} + respWriter := &udpConnResponseWriter{ + conn: conn, + h: h, + } reqSender, err := h.proxy.NewSession(respWriter) if err != nil { respWriter.Close() @@ -93,12 +97,17 @@ func (h *udpHandler) closeSession(conn lwip.UDPConn) error { // The PacketResponseWriter that will write responses to the lwip network stack. type udpConnResponseWriter struct { - conn lwip.UDPConn - h *udpHandler + closed atomic.Bool + conn lwip.UDPConn + h *udpHandler } // Write relays packets from the proxy to the lwIP TUN device. func (r *udpConnResponseWriter) WriteFrom(p []byte, source net.Addr) (int, error) { + if r.closed.Load() { + return 0, network.ErrClosed + } + // net.Addr -> *net.UDPAddr, because r.conn.WriteFrom requires *net.UDPAddr // and this is more reliable than type assertion // also the source address host will be an IP address, no actual resolution will be done @@ -106,10 +115,14 @@ func (r *udpConnResponseWriter) WriteFrom(p []byte, source net.Addr) (int, error if err != nil { return 0, err } + return r.conn.WriteFrom(p, srcAddr) } // Close informs the udpHandler to close the UDPConn and clean up the UDP session. func (r *udpConnResponseWriter) Close() error { - return r.h.closeSession(r.conn) + if r.closed.CompareAndSwap(false, true) { + return r.h.closeSession(r.conn) + } + return network.ErrClosed } diff --git a/network/lwip2transport/udp_test.go b/network/lwip2transport/udp_test.go new file mode 100644 index 00000000..5450f4e8 --- /dev/null +++ b/network/lwip2transport/udp_test.go @@ -0,0 +1,87 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lwip2transport + +import ( + "errors" + "net" + "net/netip" + "testing" + + "github.com/Jigsaw-Code/outline-internal-sdk/network" + "github.com/stretchr/testify/require" +) + +// Make sure we can successfully Close the request sender and response receiver wihout deadlock +func TestUDPResponseWriterCloseNoDeadlock(t *testing.T) { + proxy := &noopSingleSessionPacketProxy{} + h := newUDPHandler(proxy) + + // Create one and only one session in the proxy + localAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:60127")) + destAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort("1.2.3.4:4321")) + err := h.ReceiveTo(&noopLwIPUDPConn{localAddr}, []byte{}, destAddr) + require.NoError(t, err) + + // Close this single session (i.e. the request sender), it will close proxy.respWriter + // udpHandler must make sure only one `Close()` is called, and there should be no deadlocks + err = proxy.respWriter.Close() + require.NoError(t, err) + require.Exactly(t, 1, proxy.closeCnt) +} + +/********** Test Utilities **********/ + +type noopSingleSessionPacketProxy struct { + closeCnt int + respWriter network.PacketResponseReceiver +} + +func (p *noopSingleSessionPacketProxy) NewSession(respWriter network.PacketResponseReceiver) (network.PacketRequestSender, error) { + if p.respWriter != nil { + return nil, errors.New("don't support multiple sessions in this proxy") + } + p.respWriter = respWriter + return p, nil +} + +func (p *noopSingleSessionPacketProxy) Close() error { + p.closeCnt++ + return p.respWriter.Close() +} + +func (p *noopSingleSessionPacketProxy) WriteTo([]byte, netip.AddrPort) (int, error) { + return 0, nil +} + +type noopLwIPUDPConn struct { + localAddr *net.UDPAddr +} + +func (*noopLwIPUDPConn) Close() error { + return nil +} + +func (conn *noopLwIPUDPConn) LocalAddr() *net.UDPAddr { + return conn.localAddr +} + +func (*noopLwIPUDPConn) ReceiveTo(data []byte, addr *net.UDPAddr) error { + return nil +} + +func (*noopLwIPUDPConn) WriteFrom(data []byte, addr *net.UDPAddr) (int, error) { + return 0, nil +}