From f53b681798783e9153d3652bf87323389e619e5c Mon Sep 17 00:00:00 2001 From: Jayanth Varavani <1111446+jayanthvn@users.noreply.github.com> Date: Fri, 18 Aug 2023 01:20:33 -0700 Subject: [PATCH] API interface change and UTs for TC functions (#25) * UTs for tc functions * Fix vet * Add mocks * Feedback * PR feedbacl --- .github/workflows/pr-tests.yaml | 2 +- pkg/tc/generate_mocks.go | 15 ++ pkg/tc/mocks/tc_mocks.go | 104 ++++++++++ pkg/tc/tc.go | 68 ++++++- pkg/tc/tc_test.go | 327 ++++++++++++++++++++++++++++++++ 5 files changed, 506 insertions(+), 10 deletions(-) create mode 100644 pkg/tc/generate_mocks.go create mode 100644 pkg/tc/mocks/tc_mocks.go create mode 100644 pkg/tc/tc_test.go diff --git a/.github/workflows/pr-tests.yaml b/.github/workflows/pr-tests.yaml index afd82cc..652d183 100644 --- a/.github/workflows/pr-tests.yaml +++ b/.github/workflows/pr-tests.yaml @@ -33,6 +33,6 @@ jobs: - name: Build run: make build-linux - name: Unit test - run: make unit-test + run: sudo make unit-test - name: Upload code coverage uses: codecov/codecov-action@v3 diff --git a/pkg/tc/generate_mocks.go b/pkg/tc/generate_mocks.go new file mode 100644 index 0000000..ee00dc6 --- /dev/null +++ b/pkg/tc/generate_mocks.go @@ -0,0 +1,15 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. +package tc + +//go:generate go run github.com/golang/mock/mockgen -destination mocks/tc_mocks.go . BpfTc diff --git a/pkg/tc/mocks/tc_mocks.go b/pkg/tc/mocks/tc_mocks.go new file mode 100644 index 0000000..333361a --- /dev/null +++ b/pkg/tc/mocks/tc_mocks.go @@ -0,0 +1,104 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/aws/aws-ebpf-sdk-go/pkg/tc (interfaces: BpfTc) + +// Package mock_tc is a generated GoMock package. +package mock_tc + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockBpfTc is a mock of BpfTc interface. +type MockBpfTc struct { + ctrl *gomock.Controller + recorder *MockBpfTcMockRecorder +} + +// MockBpfTcMockRecorder is the mock recorder for MockBpfTc. +type MockBpfTcMockRecorder struct { + mock *MockBpfTc +} + +// NewMockBpfTc creates a new mock instance. +func NewMockBpfTc(ctrl *gomock.Controller) *MockBpfTc { + mock := &MockBpfTc{ctrl: ctrl} + mock.recorder = &MockBpfTcMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockBpfTc) EXPECT() *MockBpfTcMockRecorder { + return m.recorder +} + +// CleanupQdiscs mocks base method. +func (m *MockBpfTc) CleanupQdiscs(arg0, arg1 bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CleanupQdiscs", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// CleanupQdiscs indicates an expected call of CleanupQdiscs. +func (mr *MockBpfTcMockRecorder) CleanupQdiscs(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupQdiscs", reflect.TypeOf((*MockBpfTc)(nil).CleanupQdiscs), arg0, arg1) +} + +// TCEgressAttach mocks base method. +func (m *MockBpfTc) TCEgressAttach(arg0 string, arg1 int, arg2 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TCEgressAttach", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// TCEgressAttach indicates an expected call of TCEgressAttach. +func (mr *MockBpfTcMockRecorder) TCEgressAttach(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TCEgressAttach", reflect.TypeOf((*MockBpfTc)(nil).TCEgressAttach), arg0, arg1, arg2) +} + +// TCEgressDetach mocks base method. +func (m *MockBpfTc) TCEgressDetach(arg0 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TCEgressDetach", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// TCEgressDetach indicates an expected call of TCEgressDetach. +func (mr *MockBpfTcMockRecorder) TCEgressDetach(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TCEgressDetach", reflect.TypeOf((*MockBpfTc)(nil).TCEgressDetach), arg0) +} + +// TCIngressAttach mocks base method. +func (m *MockBpfTc) TCIngressAttach(arg0 string, arg1 int, arg2 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TCIngressAttach", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// TCIngressAttach indicates an expected call of TCIngressAttach. +func (mr *MockBpfTcMockRecorder) TCIngressAttach(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TCIngressAttach", reflect.TypeOf((*MockBpfTc)(nil).TCIngressAttach), arg0, arg1, arg2) +} + +// TCIngressDetach mocks base method. +func (m *MockBpfTc) TCIngressDetach(arg0 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TCIngressDetach", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// TCIngressDetach indicates an expected call of TCIngressDetach. +func (mr *MockBpfTcMockRecorder) TCIngressDetach(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TCIngressDetach", reflect.TypeOf((*MockBpfTc)(nil).TCIngressDetach), arg0) +} diff --git a/pkg/tc/tc.go b/pkg/tc/tc.go index f3a4f01..f71320e 100644 --- a/pkg/tc/tc.go +++ b/pkg/tc/tc.go @@ -31,6 +31,28 @@ const ( var log = logger.Get() +type BpfTc interface { + TCIngressAttach(interfaceName string, progFD int, funcName string) error + TCIngressDetach(interfaceName string) error + TCEgressAttach(interfaceName string, progFD int, funcName string) error + TCEgressDetach(interfaceName string) error + CleanupQdiscs(ingressCleanup bool, egressCleanup bool) error + mismatchedInterfacePrefix(interfaceName string) error +} + +var _ BpfTc = &bpfTc{} + +type bpfTc struct { + InterfacePrefix string +} + +func New(interfacePrefix string) BpfTc { + return &bpfTc{ + InterfacePrefix: interfacePrefix, + } + +} + func enableQdisc(link netlink.Link) bool { qdiscs, err := netlink.QdiscList(link) if err != nil { @@ -54,7 +76,20 @@ func enableQdisc(link netlink.Link) bool { } -func TCIngressAttach(interfaceName string, progFD int, funcName string) error { +func (m *bpfTc) mismatchedInterfacePrefix(interfaceName string) error { + if !strings.HasPrefix(interfaceName, m.InterfacePrefix) { + log.Errorf("expected prefix - %s but got %s", m.InterfacePrefix, interfaceName) + return errors.New("Mismatched initialized prefix name and passed interface name") + } + return nil +} + +func (m *bpfTc) TCIngressAttach(interfaceName string, progFD int, funcName string) error { + + if err := m.mismatchedInterfacePrefix(interfaceName); err != nil { + return err + } + intf, err := netlink.LinkByName(interfaceName) if err != nil { log.Errorf("failed to find device by name %s: %w", interfaceName, err) @@ -101,7 +136,12 @@ func TCIngressAttach(interfaceName string, progFD int, funcName string) error { return nil } -func TCIngressDetach(interfaceName string) error { +func (m *bpfTc) TCIngressDetach(interfaceName string) error { + + if err := m.mismatchedInterfacePrefix(interfaceName); err != nil { + return err + } + intf, err := netlink.LinkByName(interfaceName) if err != nil { log.Errorf("failed to find device by name %s: %w", interfaceName, err) @@ -132,7 +172,12 @@ func TCIngressDetach(interfaceName string) error { return fmt.Errorf("no active filter to detach-%s", interfaceName) } -func TCEgressAttach(interfaceName string, progFD int, funcName string) error { +func (m *bpfTc) TCEgressAttach(interfaceName string, progFD int, funcName string) error { + + if err := m.mismatchedInterfacePrefix(interfaceName); err != nil { + return err + } + intf, err := netlink.LinkByName(interfaceName) if err != nil { log.Errorf("failed to find device by name %s: %w", interfaceName, err) @@ -179,7 +224,12 @@ func TCEgressAttach(interfaceName string, progFD int, funcName string) error { return nil } -func TCEgressDetach(interfaceName string) error { +func (m *bpfTc) TCEgressDetach(interfaceName string) error { + + if err := m.mismatchedInterfacePrefix(interfaceName); err != nil { + return err + } + intf, err := netlink.LinkByName(interfaceName) if err != nil { log.Errorf("failed to find device by name %s: %w", interfaceName, err) @@ -210,9 +260,9 @@ func TCEgressDetach(interfaceName string) error { return fmt.Errorf("no active filter to detach-%s", interfaceName) } -func CleanupQdiscs(prefix string, ingressCleanup bool, egressCleanup bool) error { +func (m *bpfTc) CleanupQdiscs(ingressCleanup bool, egressCleanup bool) error { - if prefix == "" { + if m.InterfacePrefix == "" { log.Errorf("invalid empty prefix") return nil } @@ -225,10 +275,10 @@ func CleanupQdiscs(prefix string, ingressCleanup bool, egressCleanup bool) error for _, link := range linkList { linkName := link.Attrs().Name - if strings.HasPrefix(linkName, prefix) { + if strings.HasPrefix(linkName, m.InterfacePrefix) { if ingressCleanup { log.Infof("Trying to cleanup ingress on %s", linkName) - err = TCIngressDetach(linkName) + err = m.TCIngressDetach(linkName) if err != nil { if err.Error() == FILTER_CLEANUP_FAILED { log.Errorf("failed to detach ingress, might not be present so moving on") @@ -238,7 +288,7 @@ func CleanupQdiscs(prefix string, ingressCleanup bool, egressCleanup bool) error if egressCleanup { log.Infof("Trying to cleanup egress on %s", linkName) - err = TCEgressDetach(linkName) + err = m.TCEgressDetach(linkName) if err != nil { if err.Error() == FILTER_CLEANUP_FAILED { log.Errorf("failed to detach egress, might not be present so moving on") diff --git a/pkg/tc/tc_test.go b/pkg/tc/tc_test.go new file mode 100644 index 0000000..fec7333 --- /dev/null +++ b/pkg/tc/tc_test.go @@ -0,0 +1,327 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). +// You may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +//limitations under the License. + +package tc + +import ( + "errors" + "fmt" + "os" + "syscall" + "testing" + + constdef "github.com/aws/aws-ebpf-sdk-go/pkg/constants" + "github.com/aws/aws-ebpf-sdk-go/pkg/elfparser" + mock_ebpf_maps "github.com/aws/aws-ebpf-sdk-go/pkg/maps/mocks" + mock_ebpf_progs "github.com/aws/aws-ebpf-sdk-go/pkg/progs/mocks" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/vishvananda/netlink" +) + +const ( + DUMMY_PROG_NAME = "test" +) + +type testMocks struct { + path string + ctrl *gomock.Controller + ebpf_progs *mock_ebpf_progs.MockBpfProgAPIs + ebpf_maps *mock_ebpf_maps.MockBpfMapAPIs + tcClient BpfTc +} + +func setup(t *testing.T, testPath string, interfacePrefix string) *testMocks { + ctrl := gomock.NewController(t) + return &testMocks{ + path: testPath, + ctrl: ctrl, + ebpf_progs: mock_ebpf_progs.NewMockBpfProgAPIs(ctrl), + ebpf_maps: mock_ebpf_maps.NewMockBpfMapAPIs(ctrl), + tcClient: New(interfacePrefix), + } +} + +func mount_bpf_fs() error { + fmt.Println("Let's mount BPF FS") + err := syscall.Mount("bpf", "/sys/fs/bpf", "bpf", 0, "mode=0700") + if err != nil { + fmt.Println("error mounting bpffs") + } + return err +} + +func unmount_bpf_fs() error { + fmt.Println("Let's unmount BPF FS") + err := syscall.Unmount("/sys/fs/bpf", 0) + if err != nil { + fmt.Println("error unmounting bpffs") + } + return err +} + +func setupTest(interfaceNames []string, t *testing.T) { + mount_bpf_fs() + for _, interfaceName := range interfaceNames { + linkAttr := netlink.LinkAttrs{Name: interfaceName} + linkIFB := netlink.Ifb{} + linkIFB.LinkAttrs = linkAttr + if err := netlink.LinkAdd(&linkIFB); err != nil { + assert.NoError(t, err) + } + } +} + +func teardownTest(interfaceNames []string, t *testing.T) { + unmount_bpf_fs() + //Cleanup link + for _, interfaceName := range interfaceNames { + linkAttr := netlink.LinkAttrs{Name: interfaceName} + linkIFB := netlink.Ifb{} + linkIFB.LinkAttrs = linkAttr + if err := netlink.LinkDel(&linkIFB); err != nil { + assert.NoError(t, err) + } + } +} + +func TestMismatchedPrefixName(t *testing.T) { + m := setup(t, "../../test-data/tc.bpf.elf", "eni") + defer m.ctrl.Finish() + + tests := []struct { + name string + interfaceName string + wantErr error + }{ + { + name: "Test Matched Prefix", + interfaceName: "eni1", + wantErr: nil, + }, + { + name: "Test Mismatched Prefix", + interfaceName: "fni1", + wantErr: errors.New("Mismatched initialized prefix name and passed interface name"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + testTcClient := New("eni") + err := testTcClient.mismatchedInterfacePrefix(tt.interfaceName) + if tt.wantErr != nil { + assert.EqualError(t, err, tt.wantErr.Error()) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestTCIngressAttachDetach(t *testing.T) { + if os.Getuid() != 0 { + t.Skip("Test requires root privileges.") + } + + m := setup(t, "../../test-data/tc.bpf.elf", "f") + defer m.ctrl.Finish() + + interfaceName := "foo" + + var interfaceNames []string + interfaceNames = append(interfaceNames, interfaceName) + setupTest(interfaceNames, t) + defer teardownTest(interfaceNames, t) + + m.ebpf_maps.EXPECT().CreateBPFMap(gomock.Any()).AnyTimes() + m.ebpf_progs.EXPECT().LoadProg(gomock.Any()).AnyTimes() + m.ebpf_maps.EXPECT().PinMap(gomock.Any(), gomock.Any()).AnyTimes() + m.ebpf_maps.EXPECT().GetMapFromPinPath(gomock.Any()).AnyTimes() + m.ebpf_progs.EXPECT().GetProgFromPinPath(gomock.Any()).AnyTimes() + m.ebpf_progs.EXPECT().GetBPFProgAssociatedMapsIDs(gomock.Any()).AnyTimes() + + bpfSDKclient := elfparser.New() + progInfo, _, err := bpfSDKclient.LoadBpfFile(m.path, DUMMY_PROG_NAME) + if err != nil { + assert.NoError(t, err) + } + pinPath := constdef.PROG_BPF_FS + DUMMY_PROG_NAME + "_handle_ingress" + + progFD := progInfo[pinPath].Program.ProgFD + if err := m.tcClient.TCIngressAttach(interfaceName, progFD, DUMMY_PROG_NAME); err != nil { + assert.NoError(t, err) + } + + if err := m.tcClient.TCIngressDetach(interfaceName); err != nil { + assert.NoError(t, err) + } +} + +func TestTCEgressAttachDetach(t *testing.T) { + if os.Getuid() != 0 { + t.Skip("Test requires root privileges.") + } + + m := setup(t, "../../test-data/tc.bpf.elf", "f") + defer m.ctrl.Finish() + + interfaceName := "foo" + + var interfaceNames []string + interfaceNames = append(interfaceNames, interfaceName) + + setupTest(interfaceNames, t) + defer teardownTest(interfaceNames, t) + + m.ebpf_maps.EXPECT().CreateBPFMap(gomock.Any()).AnyTimes() + m.ebpf_progs.EXPECT().LoadProg(gomock.Any()).AnyTimes() + m.ebpf_maps.EXPECT().PinMap(gomock.Any(), gomock.Any()).AnyTimes() + m.ebpf_maps.EXPECT().GetMapFromPinPath(gomock.Any()).AnyTimes() + m.ebpf_progs.EXPECT().GetProgFromPinPath(gomock.Any()).AnyTimes() + m.ebpf_progs.EXPECT().GetBPFProgAssociatedMapsIDs(gomock.Any()).AnyTimes() + + bpfSDKclient := elfparser.New() + progInfo, _, err := bpfSDKclient.LoadBpfFile(m.path, DUMMY_PROG_NAME) + if err != nil { + assert.NoError(t, err) + } + pinPath := constdef.PROG_BPF_FS + DUMMY_PROG_NAME + "_handle_ingress" + + progFD := progInfo[pinPath].Program.ProgFD + if err := m.tcClient.TCEgressAttach(interfaceName, progFD, DUMMY_PROG_NAME); err != nil { + assert.NoError(t, err) + } + + if err := m.tcClient.TCEgressDetach(interfaceName); err != nil { + assert.NoError(t, err) + } +} + +func TestQdiscCleanup(t *testing.T) { + + if os.Getuid() != 0 { + t.Skip("Test requires root privileges.") + } + + m := setup(t, "../../test-data/tc.bpf.elf", "eni") + defer m.ctrl.Finish() + + interfaceName1 := "eni1" + interfaceName2 := "eni2" + + var interfaceNames []string + interfaceNames = append(interfaceNames, interfaceName1) + interfaceNames = append(interfaceNames, interfaceName2) + + setupTest(interfaceNames, t) + defer teardownTest(interfaceNames, t) + + m.ebpf_maps.EXPECT().CreateBPFMap(gomock.Any()).AnyTimes() + m.ebpf_progs.EXPECT().LoadProg(gomock.Any()).AnyTimes() + m.ebpf_maps.EXPECT().PinMap(gomock.Any(), gomock.Any()).AnyTimes() + m.ebpf_maps.EXPECT().GetMapFromPinPath(gomock.Any()).AnyTimes() + m.ebpf_progs.EXPECT().GetProgFromPinPath(gomock.Any()).AnyTimes() + m.ebpf_progs.EXPECT().GetBPFProgAssociatedMapsIDs(gomock.Any()).AnyTimes() + + bpfSDKclient := elfparser.New() + progInfo, _, err := bpfSDKclient.LoadBpfFile(m.path, DUMMY_PROG_NAME) + if err != nil { + assert.NoError(t, err) + } + pinPath := constdef.PROG_BPF_FS + DUMMY_PROG_NAME + "_handle_ingress" + + progFD := progInfo[pinPath].Program.ProgFD + if err := m.tcClient.TCEgressAttach(interfaceName1, progFD, DUMMY_PROG_NAME); err != nil { + assert.NoError(t, err) + } + + if err := m.tcClient.TCIngressAttach(interfaceName2, progFD, DUMMY_PROG_NAME); err != nil { + assert.NoError(t, err) + } + + if err := m.tcClient.CleanupQdiscs(true, true); err != nil { + assert.NoError(t, err) + } +} + +func TestNetLinkAPIs(t *testing.T) { + + netLinktests := []struct { + name string + interfaceName string + overrideName bool + want []int + wantErr error + }{ + { + name: "Failed Link By Name", + interfaceName: "eni1", + want: nil, + overrideName: true, + wantErr: errors.New("Link not found"), + }, + { + name: "Failed to add filter", + interfaceName: "eni1", + overrideName: false, + want: nil, + wantErr: errors.New("invalid argument"), + }, + } + + for _, tt := range netLinktests { + t.Run(tt.name, func(t *testing.T) { + m := setup(t, "../../test-data/tc.bpf.elf", "eni") + defer m.ctrl.Finish() + + var interfaceNames []string + interfaceNames = append(interfaceNames, tt.interfaceName) + + setupTest(interfaceNames, t) + defer teardownTest(interfaceNames, t) + + m.ebpf_maps.EXPECT().CreateBPFMap(gomock.Any()).AnyTimes() + m.ebpf_progs.EXPECT().LoadProg(gomock.Any()).AnyTimes() + m.ebpf_maps.EXPECT().PinMap(gomock.Any(), gomock.Any()).AnyTimes() + m.ebpf_maps.EXPECT().GetMapFromPinPath(gomock.Any()).AnyTimes() + m.ebpf_progs.EXPECT().GetProgFromPinPath(gomock.Any()).AnyTimes() + m.ebpf_progs.EXPECT().GetBPFProgAssociatedMapsIDs(gomock.Any()).AnyTimes() + + bpfSDKclient := elfparser.New() + _, _, err := bpfSDKclient.LoadBpfFile(m.path, DUMMY_PROG_NAME) + if err != nil { + assert.NoError(t, err) + } + + intfName := tt.interfaceName + if tt.overrideName { + intfName = intfName + "10" + } + err = m.tcClient.TCEgressAttach(intfName, -1, "test") + if tt.wantErr != nil { + assert.EqualError(t, err, tt.wantErr.Error()) + } else { + assert.NoError(t, err) + } + err = m.tcClient.TCIngressAttach(intfName, -1, "test") + if tt.wantErr != nil { + assert.EqualError(t, err, tt.wantErr.Error()) + } else { + assert.NoError(t, err) + } + }) + } +}