diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..a1ea715e --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +# MacOS files +.DS_Store diff --git a/README.md b/README.md index 91def7ed..830daa4e 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ Alpha tasks: - [x] Add IP Device abstraction - [x] Add IP Device implementation based on go-tun2socks (LWIP) - [ ] Add UDP handler to fallback to DNS-over-TCP - - [ ] Add DelegatePacketProxy + - [x] Add DelegatePacketProxy for runtime PacketProxy replacement ### Beta diff --git a/network/delegate_packet_proxy.go b/network/delegate_packet_proxy.go new file mode 100644 index 00000000..16e59e07 --- /dev/null +++ b/network/delegate_packet_proxy.go @@ -0,0 +1,70 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package network + +import ( + "errors" + "sync/atomic" +) + +// DelegatePacketProxy is a PacketProxy that forwards calls (like NewSession) to another PacketProxy. To create a +// DelegatePacketProxy with the default PacketProxy, use NewDelegatePacketProxy. To change the underlying PacketProxy, +// use SetProxy. +// +// Note: After changing the underlying PacketProxy, only new NewSession calls will be routed to the new PacketProxy. +// Existing sessions will not be affected. +// +// Multiple goroutines may invoke methods on a DelegatePacketProxy simultaneously. +type DelegatePacketProxy interface { + PacketProxy + + // SetProxy updates the underlying PacketProxy to `proxy`. And `proxy` must not be nil. After this function + // returns, all new PacketProxy calls will be forwarded to the `proxy`. Existing sessions will not be affected. + SetProxy(proxy PacketProxy) error +} + +var errInvalidProxy = errors.New("the underlying proxy must not be nil") + +// Compilation guard against interface implementation +var _ DelegatePacketProxy = (*delegatePacketProxy)(nil) + +type delegatePacketProxy struct { + proxy atomic.Value +} + +// NewDelegatePacketProxy creates a new [DelegatePacketProxy] that forwards calls to the `proxy` [PacketProxy]. +// The `proxy` must not be nil. +func NewDelegatePacketProxy(proxy PacketProxy) (DelegatePacketProxy, error) { + if proxy == nil { + return nil, errInvalidProxy + } + dp := delegatePacketProxy{} + dp.proxy.Store(proxy) + return &dp, nil +} + +// NewSession implements PacketProxy.NewSession, and it will forward the call to the underlying PacketProxy. +func (p *delegatePacketProxy) NewSession(respWriter PacketResponseReceiver) (PacketRequestSender, error) { + return p.proxy.Load().(PacketProxy).NewSession(respWriter) +} + +// SetProxy implements DelegatePacketProxy.SetProxy. +func (p *delegatePacketProxy) SetProxy(proxy PacketProxy) error { + if proxy == nil { + return errInvalidProxy + } + p.proxy.Store(proxy) + return nil +} diff --git a/network/delegate_packet_proxy_test.go b/network/delegate_packet_proxy_test.go new file mode 100644 index 00000000..fe23b584 --- /dev/null +++ b/network/delegate_packet_proxy_test.go @@ -0,0 +1,140 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package network + +import ( + "sync" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/require" +) + +// Make sure the underlying packet proxy can be initialized and updated +func TestProxyCanBeUpdated(t *testing.T) { + defProxy := &sessionCountPacketProxy{} + newProxy := &sessionCountPacketProxy{} + p, err := NewDelegatePacketProxy(defProxy) + require.NotNil(t, p) + require.NoError(t, err) + + // Initially no NewSession is called + require.Exactly(t, 0, defProxy.Count()) + require.Exactly(t, 0, newProxy.Count()) + + snd, err := p.NewSession(nil) + require.Nil(t, snd) + require.NoError(t, err) + + // defProxy.NewSession's count++ + require.Exactly(t, 1, defProxy.Count()) + require.Exactly(t, 0, newProxy.Count()) + + // SetProxy should not call NewSession + err = p.SetProxy(newProxy) + require.NoError(t, err) + require.Exactly(t, 1, defProxy.Count()) + require.Exactly(t, 0, newProxy.Count()) + + // newProxy.NewSession's count += 2 + snd, err = p.NewSession(nil) + require.Nil(t, snd) + require.NoError(t, err) + + snd, err = p.NewSession(nil) + require.Nil(t, snd) + require.NoError(t, err) + + require.Exactly(t, 1, defProxy.Count()) + require.Exactly(t, 2, newProxy.Count()) +} + +// Make sure multiple goroutines can call NewSession and SetProxy concurrently +// Need to run this test with `-race` flag +func TestSetProxyRaceCondition(t *testing.T) { + const proxiesCnt = 10 + const sessionCntPerProxy = 5 + + var proxies [proxiesCnt]*sessionCountPacketProxy + for i := 0; i < proxiesCnt; i++ { + proxies[i] = &sessionCountPacketProxy{} + } + + dp, err := NewDelegatePacketProxy(proxies[0]) + require.NotNil(t, dp) + require.NoError(t, err) + + setProxyTask := &sync.WaitGroup{} + cancelSetProxy := &atomic.Bool{} + setProxyTask.Add(1) + go func() { + for i := 0; !cancelSetProxy.Load(); i = (i + 1) % proxiesCnt { + err := dp.SetProxy(proxies[i]) + require.NoError(t, err) + } + setProxyTask.Done() + }() + + newSessionTask := &sync.WaitGroup{} + newSessionTask.Add(1) + go func() { + for i := 0; i < proxiesCnt*sessionCntPerProxy; i++ { + dp.NewSession(nil) + } + newSessionTask.Done() + }() + + newSessionTask.Wait() + cancelSetProxy.Store(true) + setProxyTask.Wait() + + expectedTotal := proxiesCnt * sessionCntPerProxy + actualTotal := 0 + for i := 0; i < proxiesCnt; i++ { + require.GreaterOrEqual(t, proxies[i].Count(), 0) + actualTotal += proxies[i].Count() + } + require.Equal(t, expectedTotal, actualTotal) +} + +// Make sure we cannot SetProxy to nil +func TestSetProxyWithNilValue(t *testing.T) { + // must not initialize with nil + dp, err := NewDelegatePacketProxy(nil) + require.Error(t, err) + require.Nil(t, dp) + + dp, err = NewDelegatePacketProxy(&sessionCountPacketProxy{}) + require.NoError(t, err) + require.NotNil(t, dp) + + // must not SetProxy to nil + err = dp.SetProxy(nil) + require.Error(t, err) +} + +// sessionCountPacketProxy logs the count of the NewSession calls, and returns a nil PacketRequestSender +type sessionCountPacketProxy struct { + cnt atomic.Int32 +} + +func (sp *sessionCountPacketProxy) NewSession(respWriter PacketResponseReceiver) (PacketRequestSender, error) { + sp.cnt.Add(1) + return nil, nil +} + +func (sp *sessionCountPacketProxy) Count() int { + return int(sp.cnt.Load()) +}