diff --git a/controllers/drain_controller_test.go b/controllers/drain_controller_test.go new file mode 100644 index 000000000..7f500ffbf --- /dev/null +++ b/controllers/drain_controller_test.go @@ -0,0 +1,126 @@ +package controllers + +import ( + goctx "context" + "time" + + v1 "k8s.io/api/core/v1" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + sriovnetworkv1 "github.com/k8snetworkplumbingwg/sriov-network-operator/api/v1" + consts "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/consts" + "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/utils" + "github.com/k8snetworkplumbingwg/sriov-network-operator/test/util" +) + +func createNode(name, anno string) *v1.Node { + node := &v1.Node{} + node.Name = name + node.Annotations = map[string]string{} + node.Annotations[consts.NodeDrainAnnotation] = anno + + Expect(k8sClient.Create(goctx.TODO(), node)).Should(Succeed()) + return node +} + +var _ = Describe("Drain Controller", func() { + + Context("Parallel nodes draining", func() { + + It("Should drain one node", func() { + node1 := createNode("node1", "Drain_Required") + node2 := createNode("node2", "Drain_Required") + + config := &sriovnetworkv1.SriovOperatorConfig{} + err := util.WaitForNamespacedObject(config, k8sClient, testNamespace, "default", interval, timeout) + Expect(err).NotTo(HaveOccurred()) + config.Spec = sriovnetworkv1.SriovOperatorConfigSpec{ + MaxParallelNodeConfiguration: 1, + } + updateErr := k8sClient.Update(goctx.TODO(), config) + Expect(updateErr).NotTo(HaveOccurred()) + time.Sleep(3 * time.Second) + + nodeList := &v1.NodeList{} + listErr := k8sClient.List(ctx, nodeList) + Expect(listErr).NotTo(HaveOccurred()) + + drainingNodes := 0 + for _, node := range nodeList.Items { + if utils.NodeHasAnnotation(node, "sriovnetwork.openshift.io/state", "Draining") { + drainingNodes++ + } + } + Expect(drainingNodes).To(Equal(1)) + + err = k8sClient.Delete(goctx.TODO(), node1) + Expect(err).NotTo(HaveOccurred()) + err = k8sClient.Delete(goctx.TODO(), node2) + Expect(err).NotTo(HaveOccurred()) + }) + + It("Should drain two nodes", func() { + node1 := createNode("node1", "Drain_Required") + node2 := createNode("node2", "Drain_Required") + + config := &sriovnetworkv1.SriovOperatorConfig{} + err := util.WaitForNamespacedObject(config, k8sClient, testNamespace, "default", interval, timeout) + Expect(err).NotTo(HaveOccurred()) + config.Spec = sriovnetworkv1.SriovOperatorConfigSpec{ + MaxParallelNodeConfiguration: 2, + } + updateErr := k8sClient.Update(goctx.TODO(), config) + Expect(updateErr).NotTo(HaveOccurred()) + time.Sleep(3 * time.Second) + + nodeList := &v1.NodeList{} + listErr := k8sClient.List(ctx, nodeList) + Expect(listErr).NotTo(HaveOccurred()) + + drainingNodes := 0 + for _, node := range nodeList.Items { + Expect(utils.NodeHasAnnotation(node, "sriovnetwork.openshift.io/state", "Draining")).To(BeTrue()) + } + Expect(drainingNodes).To(Equal(1)) + + err = k8sClient.Delete(goctx.TODO(), node1) + Expect(err).NotTo(HaveOccurred()) + err = k8sClient.Delete(goctx.TODO(), node2) + Expect(err).NotTo(HaveOccurred()) + + }) + + It("Should drain all nodes", func() { + node1 := createNode("node1", "Drain_Required") + node2 := createNode("node2", "Drain_Required") + + config := &sriovnetworkv1.SriovOperatorConfig{} + err := util.WaitForNamespacedObject(config, k8sClient, testNamespace, "default", interval, timeout) + Expect(err).NotTo(HaveOccurred()) + config.Spec = sriovnetworkv1.SriovOperatorConfigSpec{ + MaxParallelNodeConfiguration: 0, + } + updateErr := k8sClient.Update(goctx.TODO(), config) + Expect(updateErr).NotTo(HaveOccurred()) + time.Sleep(3 * time.Second) + + nodeList := &v1.NodeList{} + listErr := k8sClient.List(ctx, nodeList) + Expect(listErr).NotTo(HaveOccurred()) + + drainingNodes := 0 + for _, node := range nodeList.Items { + Expect(utils.NodeHasAnnotation(node, "sriovnetwork.openshift.io/state", "Draining")).To(BeTrue()) + } + Expect(drainingNodes).To(Equal(1)) + + err = k8sClient.Delete(goctx.TODO(), node1) + Expect(err).NotTo(HaveOccurred()) + err = k8sClient.Delete(goctx.TODO(), node2) + Expect(err).NotTo(HaveOccurred()) + + }) + }) +}) diff --git a/pkg/daemon/daemon.go b/pkg/daemon/daemon.go index 4c653a17d..bfe25fd27 100644 --- a/pkg/daemon/daemon.go +++ b/pkg/daemon/daemon.go @@ -38,6 +38,7 @@ import ( sriovnetworkv1 "github.com/k8snetworkplumbingwg/sriov-network-operator/api/v1" snclientset "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/client/clientset/versioned" sninformer "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/client/informers/externalversions" + consts "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/consts" plugin "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/plugins" "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/utils" ) @@ -562,7 +563,7 @@ func (dn *Daemon) nodeStateSyncHandler() error { // isNodeDraining: check if the node is draining // both Draining and MCP paused labels will return true func (dn *Daemon) isNodeDraining() bool { - anno, ok := dn.node.Annotations[annoKey] + anno, ok := dn.node.Annotations[consts.NodeDrainAnnotation] if !ok { return false } @@ -689,8 +690,8 @@ func (dn *Daemon) annotateNode(node, value string) error { if newNode.Annotations == nil { newNode.Annotations = map[string]string{} } - if newNode.Annotations[annoKey] != value { - newNode.Annotations[annoKey] = value + if newNode.Annotations[consts.NodeDrainAnnotation] != value { + newNode.Annotations[consts.NodeDrainAnnotation] = value newData, err := json.Marshal(newNode) if err != nil { return err