diff --git a/go.mod b/go.mod index f4dd5be..9ad5764 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( ) require ( + github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect github.com/Microsoft/go-winio v0.6.1 // indirect github.com/Microsoft/hcsshim v0.10.0-rc.1 // indirect github.com/containerd/cgroups v1.1.0 // indirect @@ -54,6 +55,7 @@ require ( github.com/moby/term v0.5.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/morikuni/aec v1.0.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/onsi/ginkgo/v2 v2.9.2 // indirect github.com/onsi/gomega v1.27.6 // indirect diff --git a/go.sum b/go.sum index 215a541..6d1735c 100644 --- a/go.sum +++ b/go.sum @@ -291,6 +291,7 @@ github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsr github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.11/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= github.com/cyphar/filepath-securejoin v0.2.2/go.mod h1:FpkQEhXnPnOthhzymB7CGsFk2G9VLXONKD9G7QGMM+4= github.com/cyphar/filepath-securejoin v0.2.3/go.mod h1:aPGpWjXOXUn2NCNjFvBE6aRxGGx79pTxQpKOJNYHHl4= github.com/d2g/dhcp4 v0.0.0-20170904100407-a1d1b6c41b1c/go.mod h1:Ct2BUK8SB0YC1SMSibvLzxjeJLnrYEVLULFNiHY9YfQ= diff --git a/pkg/tracker/apitracker_test.go b/pkg/tracker/apitracker_test.go index bfc7714..a55ee09 100644 --- a/pkg/tracker/apitracker_test.go +++ b/pkg/tracker/apitracker_test.go @@ -24,6 +24,7 @@ import ( "github.com/docker/go-connections/nat" "github.com/rancher-sandbox/rancher-desktop-agent/pkg/tracker" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) const ( @@ -46,7 +47,7 @@ func TestBasicAdd(t *testing.T) { mux.HandleFunc("/services/forwarder/expose", func(w http.ResponseWriter, r *http.Request) { err := json.NewDecoder(r.Body).Decode(&expectedExposeReq) - assert.NoError(t, err) + require.NoError(t, err) }) testSrv := httptest.NewServer(mux) @@ -62,7 +63,7 @@ func TestBasicAdd(t *testing.T) { }, } err := apiTracker.Add(containerID, portMapping) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, expectedExposeReq.Local, ipPortBuilder(hostIP, hostPort)) assert.Equal(t, expectedExposeReq.Remote, ipPortBuilder(hostSwitchIP, hostPort)) @@ -81,7 +82,7 @@ func TestAddOverride(t *testing.T) { mux.HandleFunc("/services/forwarder/expose", func(w http.ResponseWriter, r *http.Request) { var tmpReq *types.ExposeRequest err := json.NewDecoder(r.Body).Decode(&tmpReq) - assert.NoError(t, err) + require.NoError(t, err) expectedExposeReq = append(expectedExposeReq, tmpReq) }) @@ -104,7 +105,7 @@ func TestAddOverride(t *testing.T) { }, } err := apiTracker.Add(containerID, portMapping) - assert.NoError(t, err) + require.NoError(t, err) assert.ElementsMatch(t, expectedExposeReq, []*types.ExposeRequest{ @@ -139,7 +140,7 @@ func TestAddOverride(t *testing.T) { }, } err = apiTracker.Add(containerID, portMapping2) - assert.NoError(t, err) + require.NoError(t, err) assert.ElementsMatch(t, expectedExposeReq, []*types.ExposeRequest{ @@ -167,7 +168,7 @@ func TestAddWithError(t *testing.T) { mux.HandleFunc("/services/forwarder/expose", func(w http.ResponseWriter, r *http.Request) { var tmpReq *types.ExposeRequest err := json.NewDecoder(r.Body).Decode(&tmpReq) - assert.NoError(t, err) + require.NoError(t, err) if tmpReq.Local == ipPortBuilder(hostIP2, hostPort) { http.Error(w, "Bad API error", http.StatusRequestTimeout) @@ -197,7 +198,7 @@ func TestAddWithError(t *testing.T) { }, } err := apiTracker.Add(containerID, portMapping) - assert.Error(t, err) + require.Error(t, err) errPortBinding := nat.PortBinding{ HostIP: hostIP2, @@ -208,7 +209,7 @@ func TestAddWithError(t *testing.T) { fmt.Errorf("exposing %+v failed: %w", errPortBinding, nestedErr), } expectedErr := fmt.Errorf("%w: %+v", tracker.ErrExposeAPI, errs) - assert.EqualError(t, err, expectedErr.Error()) + require.EqualError(t, err, expectedErr.Error()) assert.Len(t, expectedExposeReq, 2) assert.ElementsMatch(t, expectedExposeReq, @@ -236,7 +237,7 @@ func TestAddWithError(t *testing.T) { HostIP: hostIP2, HostPort: hostPort, }) - assert.Equal(t, actualPortMapping["80/tcp"], + assert.Equal(t, []nat.PortBinding{ { HostIP: hostIP, @@ -246,7 +247,7 @@ func TestAddWithError(t *testing.T) { HostIP: hostIP3, HostPort: hostPort, }, - }) + }, actualPortMapping["80/tcp"]) } func TestGet(t *testing.T) { @@ -276,7 +277,7 @@ func TestGet(t *testing.T) { apiTracker := tracker.NewAPITracker(testSrv.URL, true) err := apiTracker.Add(containerID, portMapping) - assert.NoError(t, err) + require.NoError(t, err) actualPortMappings := apiTracker.Get(containerID) assert.Len(t, actualPortMappings, len(portMapping)) @@ -296,7 +297,7 @@ func TestRemove(t *testing.T) { mux.HandleFunc("/services/forwarder/unexpose", func(w http.ResponseWriter, r *http.Request) { err := json.NewDecoder(r.Body).Decode(&expectedUnexposeReq) - assert.NoError(t, err) + require.NoError(t, err) }) testSrv := httptest.NewServer(mux) @@ -324,12 +325,12 @@ func TestRemove(t *testing.T) { }, } err := apiTracker.Add(containerID, portMapping1) - assert.NoError(t, err) + require.NoError(t, err) err = apiTracker.Add(containerID2, portMapping2) - assert.NoError(t, err) + require.NoError(t, err) err = apiTracker.Remove(containerID) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, expectedUnexposeReq.Local, ipPortBuilder(hostIP, hostPort)) @@ -354,7 +355,7 @@ func TestRemoveWithError(t *testing.T) { mux.HandleFunc("/services/forwarder/unexpose", func(w http.ResponseWriter, r *http.Request) { var tmpReq *types.UnexposeRequest err := json.NewDecoder(r.Body).Decode(&tmpReq) - assert.NoError(t, err) + require.NoError(t, err) if tmpReq.Local == ipPortBuilder(hostIP2, hostPort) { http.Error(w, "Test API error", http.StatusRequestTimeout) @@ -385,10 +386,10 @@ func TestRemoveWithError(t *testing.T) { }, } err := apiTracker.Add(containerID, portMapping) - assert.NoError(t, err) + require.NoError(t, err) err = apiTracker.Remove(containerID) - assert.Error(t, err) + require.Error(t, err) errPortBinding := nat.PortBinding{ HostIP: hostIP2, @@ -399,7 +400,7 @@ func TestRemoveWithError(t *testing.T) { fmt.Errorf("unexposing %+v failed: %w", errPortBinding, nestedErr), } expectedErr := fmt.Errorf("%w: %+v", tracker.ErrUnexposeAPI, errs) - assert.EqualError(t, err, expectedErr.Error()) + require.EqualError(t, err, expectedErr.Error()) assert.ElementsMatch(t, expectedUnexposeReq, []*types.UnexposeRequest{ {Local: ipPortBuilder(hostIP, hostPort)}, @@ -449,13 +450,13 @@ func TestRemoveAll(t *testing.T) { }, } err := apiTracker.Add(containerID, portMapping1) - assert.NoError(t, err) + require.NoError(t, err) err = apiTracker.Add(containerID2, portMapping2) - assert.NoError(t, err) + require.NoError(t, err) err = apiTracker.RemoveAll() - assert.NoError(t, err) + require.NoError(t, err) expectedPortMapping1 := apiTracker.Get(containerID) assert.Nil(t, expectedPortMapping1) @@ -478,7 +479,7 @@ func TestRemoveAllWithError(t *testing.T) { mux.HandleFunc("/services/forwarder/unexpose", func(w http.ResponseWriter, r *http.Request) { var tmpReq *types.UnexposeRequest err := json.NewDecoder(r.Body).Decode(&tmpReq) - assert.NoError(t, err) + require.NoError(t, err) if tmpReq.Local == ipPortBuilder(hostIP2, hostPort2) { http.Error(w, "RemoveAll API error", http.StatusRequestTimeout) @@ -513,13 +514,13 @@ func TestRemoveAllWithError(t *testing.T) { }, } err := apiTracker.Add(containerID, portMapping1) - assert.NoError(t, err) + require.NoError(t, err) err = apiTracker.Add(containerID2, portMapping2) - assert.NoError(t, err) + require.NoError(t, err) err = apiTracker.RemoveAll() - assert.Error(t, err) + require.Error(t, err) errPortBinding := nat.PortBinding{ HostIP: hostIP2, @@ -530,7 +531,7 @@ func TestRemoveAllWithError(t *testing.T) { fmt.Errorf("RemoveAll unexposing %+v failed: %w", errPortBinding, nestedErr), } expectedErr := fmt.Errorf("%w: %+v", tracker.ErrUnexposeAPI, errs) - assert.EqualError(t, err, expectedErr.Error()) + require.EqualError(t, err, expectedErr.Error()) assert.ElementsMatch(t, expectedUnexposeReq, []*types.UnexposeRequest{ {Local: ipPortBuilder(hostIP, hostPort)}, @@ -554,7 +555,7 @@ func TestNonAdminInstall(t *testing.T) { mux.HandleFunc("/services/forwarder/expose", func(w http.ResponseWriter, r *http.Request) { var tmpReq *types.ExposeRequest err := json.NewDecoder(r.Body).Decode(&tmpReq) - assert.NoError(t, err) + require.NoError(t, err) expectedExposeReq = append(expectedExposeReq, tmpReq) }) @@ -563,7 +564,7 @@ func TestNonAdminInstall(t *testing.T) { mux.HandleFunc("/services/forwarder/unexpose", func(w http.ResponseWriter, r *http.Request) { var tmpReq *types.UnexposeRequest err := json.NewDecoder(r.Body).Decode(&tmpReq) - assert.NoError(t, err) + require.NoError(t, err) expectedUnexposeReq = append(expectedUnexposeReq, tmpReq) }) @@ -582,7 +583,7 @@ func TestNonAdminInstall(t *testing.T) { } err := apiTracker.Add(containerID, portMapping) - assert.NoError(t, err) + require.NoError(t, err) assert.ElementsMatch(t, expectedExposeReq, []*types.ExposeRequest{ @@ -594,7 +595,7 @@ func TestNonAdminInstall(t *testing.T) { ) err = apiTracker.Remove(containerID) - assert.NoError(t, err) + require.NoError(t, err) assert.ElementsMatch(t, expectedUnexposeReq, []*types.UnexposeRequest{ diff --git a/pkg/tracker/listenertracker_test.go b/pkg/tracker/listenertracker_test.go index c1b220c..a8a3808 100644 --- a/pkg/tracker/listenertracker_test.go +++ b/pkg/tracker/listenertracker_test.go @@ -22,7 +22,7 @@ import ( "testing" "github.com/rancher-sandbox/rancher-desktop-agent/pkg/tracker" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestListenerTracker(t *testing.T) { @@ -46,16 +46,16 @@ func TestListenerTracker(t *testing.T) { t.Run(fmt.Sprintf("Should create listener with port: %d", testCase.testPort), func(t *testing.T) { t.Parallel() err := listenerTracker.AddListener(ctx, testIPAddr, testCase.testPort) - assert.Nil(t, err) + require.NoError(t, err) _, err = net.Dial("tcp", ipPortToAddr(testIPAddr, testCase.testPort)) - assert.Nil(t, err) + require.NoError(t, err) err = listenerTracker.RemoveListener(ctx, testIPAddr, testCase.testPort) - assert.Nil(t, err) + require.NoError(t, err) _, err = net.Dial("tcp", ipPortToAddr(testIPAddr, testCase.testPort)) - assert.ErrorIs(t, err, syscall.ECONNREFUSED) + require.ErrorIs(t, err, syscall.ECONNREFUSED) }) } } diff --git a/pkg/tracker/vtunneltracker_test.go b/pkg/tracker/vtunneltracker_test.go index b2a4b81..912e8d7 100644 --- a/pkg/tracker/vtunneltracker_test.go +++ b/pkg/tracker/vtunneltracker_test.go @@ -23,6 +23,7 @@ import ( "github.com/rancher-sandbox/rancher-desktop-agent/pkg/tracker" "github.com/rancher-sandbox/rancher-desktop-agent/pkg/types" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestVTunnelTrackerAdd(t *testing.T) { @@ -41,7 +42,7 @@ func TestVTunnelTrackerAdd(t *testing.T) { }, } err := vtunnelTracker.Add(containerID, portMapping) - assert.NoError(t, err) + require.NoError(t, err) portMapping2 := nat.PortMap{ "443/tcp": []nat.PortBinding{ @@ -52,7 +53,7 @@ func TestVTunnelTrackerAdd(t *testing.T) { }, } err = vtunnelTracker.Add(containerID2, portMapping2) - assert.NoError(t, err) + require.NoError(t, err) assert.ElementsMatch(t, forwarder.receivedPortMappings, []types.PortMapping{ @@ -96,7 +97,7 @@ func TestVTunnelTrackerAddOverride(t *testing.T) { }, } err := vtunnelTracker.Add(containerID, portMapping) - assert.NoError(t, err) + require.NoError(t, err) assert.ElementsMatch(t, forwarder.receivedPortMappings, []types.PortMapping{ @@ -126,16 +127,16 @@ func TestVTunnelTrackerAddOverride(t *testing.T) { } err = vtunnelTracker.Add(containerID, portMapping2) - assert.NoError(t, err) + require.NoError(t, err) secondCallIndex := 1 - assert.Equal(t, forwarder.receivedPortMappings[secondCallIndex], + assert.Equal(t, types.PortMapping{ Remove: false, Ports: portMapping2, ConnectAddrs: wslConnectAddr, }, - ) + forwarder.receivedPortMappings[secondCallIndex]) actualPortMapping = vtunnelTracker.Get(containerID) assert.Equal(t, actualPortMapping, portMapping2) @@ -151,12 +152,12 @@ func TestVTunnelTrackerAddEmptyPortMap(t *testing.T) { portMapping := nat.PortMap{} err := vtunnelTracker.Add(containerID, portMapping) - assert.NoError(t, err) + require.NoError(t, err) - assert.Len(t, forwarder.receivedPortMappings, 0) + assert.Empty(t, forwarder.receivedPortMappings, 0) actualPortMapping := vtunnelTracker.Get(containerID) - assert.Len(t, actualPortMapping, 0) + assert.Empty(t, actualPortMapping, 0) } func TestVTunnelTrackerAddWithError(t *testing.T) { @@ -176,7 +177,7 @@ func TestVTunnelTrackerAddWithError(t *testing.T) { }, } err := vtunnelTracker.Add(containerID, portMapping) - assert.ErrorIs(t, err, errSend) + require.ErrorIs(t, err, errSend) assert.ElementsMatch(t, forwarder.receivedPortMappings, []types.PortMapping{ @@ -188,7 +189,7 @@ func TestVTunnelTrackerAddWithError(t *testing.T) { }) actualPortMapping := vtunnelTracker.Get(containerID) - assert.Len(t, actualPortMapping, 0) + assert.Empty(t, actualPortMapping, 0) } func TestVTunnelTrackerRemove(t *testing.T) { @@ -207,7 +208,7 @@ func TestVTunnelTrackerRemove(t *testing.T) { }, } err := vtunnelTracker.Add(containerID, portMapping) - assert.NoError(t, err) + require.NoError(t, err) portMapping2 := nat.PortMap{ "443/tcp": []nat.PortBinding{ @@ -218,33 +219,33 @@ func TestVTunnelTrackerRemove(t *testing.T) { }, } err = vtunnelTracker.Add(containerID2, portMapping2) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, forwarder.receivedPortMappings, 2) err = vtunnelTracker.Remove(containerID) - assert.NoError(t, err) + require.NoError(t, err) removeRequestIndex := 2 - assert.Equal(t, forwarder.receivedPortMappings[removeRequestIndex], + assert.Equal(t, types.PortMapping{ Remove: true, Ports: portMapping, ConnectAddrs: wslConnectAddr, - }) + }, forwarder.receivedPortMappings[removeRequestIndex]) actualPortMapping := vtunnelTracker.Get(containerID) assert.Nil(t, actualPortMapping) actualPortMapping = vtunnelTracker.Get(containerID2) - assert.Equal(t, actualPortMapping, nat.PortMap{ + assert.Equal(t, nat.PortMap{ "443/tcp": []nat.PortBinding{ { HostIP: hostIP2, HostPort: hostPort2, }, }, - }) + }, actualPortMapping) } func TestVTunnelTrackerRemoveZeroLengthPortMap(t *testing.T) { @@ -256,10 +257,10 @@ func TestVTunnelTrackerRemoveZeroLengthPortMap(t *testing.T) { portMapping := nat.PortMap{} err := vtunnelTracker.Add(containerID, portMapping) - assert.NoError(t, err) + require.NoError(t, err) err = vtunnelTracker.Remove(containerID) - assert.NoError(t, err) + require.NoError(t, err) } func TestVTunnelTrackerRemoveError(t *testing.T) { @@ -278,7 +279,7 @@ func TestVTunnelTrackerRemoveError(t *testing.T) { }, } err := vtunnelTracker.Add(containerID, portMapping) - assert.NoError(t, err) + require.NoError(t, err) portMapping2 := nat.PortMap{ "443/tcp": []nat.PortBinding{ @@ -289,41 +290,41 @@ func TestVTunnelTrackerRemoveError(t *testing.T) { }, } err = vtunnelTracker.Add(containerID2, portMapping2) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, forwarder.receivedPortMappings, 2) forwarder.sendErr = errSend err = vtunnelTracker.Remove(containerID) - assert.Error(t, err) + require.Error(t, err) removeRequestIndex := 2 - assert.Equal(t, forwarder.receivedPortMappings[removeRequestIndex], + assert.Equal(t, types.PortMapping{ Remove: true, Ports: portMapping, ConnectAddrs: wslConnectAddr, - }) + }, forwarder.receivedPortMappings[removeRequestIndex]) actualPortMapping := vtunnelTracker.Get(containerID) - assert.Equal(t, actualPortMapping, nat.PortMap{ + assert.Equal(t, nat.PortMap{ "80/tcp": []nat.PortBinding{ { HostIP: hostIP, HostPort: hostPort, }, }, - }) + }, actualPortMapping) actualPortMapping = vtunnelTracker.Get(containerID2) - assert.Equal(t, actualPortMapping, nat.PortMap{ + assert.Equal(t, nat.PortMap{ "443/tcp": []nat.PortBinding{ { HostIP: hostIP2, HostPort: hostPort2, }, }, - }) + }, actualPortMapping) } func TestVTunnelTrackerRemoveAll(t *testing.T) { @@ -342,7 +343,7 @@ func TestVTunnelTrackerRemoveAll(t *testing.T) { }, } err := vtunnelTracker.Add(containerID, portMapping) - assert.NoError(t, err) + require.NoError(t, err) portMapping2 := nat.PortMap{ "443/tcp": []nat.PortBinding{ @@ -353,10 +354,10 @@ func TestVTunnelTrackerRemoveAll(t *testing.T) { }, } err = vtunnelTracker.Add(containerID2, portMapping2) - assert.NoError(t, err) + require.NoError(t, err) err = vtunnelTracker.RemoveAll() - assert.NoError(t, err) + require.NoError(t, err) actualPortMapping := vtunnelTracker.Get(containerID) assert.Nil(t, actualPortMapping) @@ -404,7 +405,7 @@ func TestVTunnelTrackerRemoveAllError(t *testing.T) { }, } err := vtunnelTracker.Add(containerID, portMapping) - assert.NoError(t, err) + require.NoError(t, err) portMapping2 := nat.PortMap{ "443/tcp": []nat.PortBinding{ @@ -415,7 +416,7 @@ func TestVTunnelTrackerRemoveAllError(t *testing.T) { }, } err = vtunnelTracker.Add(containerID2, portMapping2) - assert.NoError(t, err) + require.NoError(t, err) forwarder.failCondition = func(pm types.PortMapping) error { if _, ok := pm.Ports["443/tcp"]; ok { @@ -428,7 +429,7 @@ func TestVTunnelTrackerRemoveAllError(t *testing.T) { return nil } err = vtunnelTracker.RemoveAll() - assert.ErrorIs(t, err, tracker.ErrRemoveAll) + require.ErrorIs(t, err, tracker.ErrRemoveAll) actualPortMapping := vtunnelTracker.Get(containerID) assert.Nil(t, actualPortMapping) @@ -471,7 +472,7 @@ func TestVTunnelTrackerGet(t *testing.T) { }, } err := vtunnelTracker.Add(containerID, portMapping) - assert.NoError(t, err) + require.NoError(t, err) actualPortMap := vtunnelTracker.Get(containerID) assert.Equal(t, portMapping, actualPortMap)