diff --git a/controllers/drain_controller.go b/controllers/drain_controller.go
index e0f983d4e..25def6a79 100644
--- a/controllers/drain_controller.go
+++ b/controllers/drain_controller.go
@@ -3,6 +3,9 @@ package controllers
 import (
 	"context"
 	"fmt"
+	"sort"
+	"strings"
+
 	corev1 "k8s.io/api/core/v1"
 	"k8s.io/apimachinery/pkg/runtime"
 	"k8s.io/apimachinery/pkg/types"
@@ -16,8 +19,6 @@ import (
 	"sigs.k8s.io/controller-runtime/pkg/predicate"
 	"sigs.k8s.io/controller-runtime/pkg/reconcile"
 	"sigs.k8s.io/controller-runtime/pkg/source"
-	"sort"
-	"strings"
 
 	sriovnetworkv1 "github.com/k8snetworkplumbingwg/sriov-network-operator/api/v1"
 	constants "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/consts"
diff --git a/controllers/drain_controller_test.go b/controllers/drain_controller_test.go
new file mode 100644
index 000000000..c0fe7f226
--- /dev/null
+++ b/controllers/drain_controller_test.go
@@ -0,0 +1,120 @@
+package controllers
+
+import (
+	goctx "context"
+	"time"
+
+	v1 "k8s.io/api/core/v1"
+
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
+
+	sriovnetworkv1 "github.com/k8snetworkplumbingwg/sriov-network-operator/api/v1"
+	consts "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/consts"
+	"github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/utils"
+	"github.com/k8snetworkplumbingwg/sriov-network-operator/test/util"
+)
+
+func createNodeObj(name, anno string) *v1.Node {
+	node := &v1.Node{}
+	node.Name = name
+	node.Annotations = map[string]string{}
+	node.Annotations[consts.NodeDrainAnnotation] = anno
+
+	return node
+}
+
+func createNode(node *v1.Node) {
+	Expect(k8sClient.Create(goctx.TODO(), node)).Should(Succeed())
+}
+
+var _ = Describe("Drain Controller", func() {
+
+	BeforeEach(func() {
+		node1 := createNodeObj("node1", "Drain_Required")
+		node2 := createNodeObj("node2", "Drain_Required")
+		createNode(node1)
+		createNode(node2)
+	})
+	AfterEach(func() {
+		node1 := createNodeObj("node1", "Drain_Required")
+		node2 := createNodeObj("node2", "Drain_Required")
+		err := k8sClient.Delete(goctx.TODO(), node1)
+		Expect(err).NotTo(HaveOccurred())
+		err = k8sClient.Delete(goctx.TODO(), node2)
+		Expect(err).NotTo(HaveOccurred())
+	})
+
+	Context("Parallel nodes draining", func() {
+
+		It("Should drain one node", func() {
+			config := &sriovnetworkv1.SriovOperatorConfig{}
+			err := util.WaitForNamespacedObject(config, k8sClient, testNamespace, "default", interval, timeout)
+			Expect(err).NotTo(HaveOccurred())
+			config.Spec = sriovnetworkv1.SriovOperatorConfigSpec{
+				EnableInjector:               func() *bool { b := true; return &b }(),
+				EnableOperatorWebhook:        func() *bool { b := true; return &b }(),
+				MaxParallelNodeConfiguration: 1,
+			}
+			updateErr := k8sClient.Update(goctx.TODO(), config)
+			Expect(updateErr).NotTo(HaveOccurred())
+			time.Sleep(3 * time.Second)
+
+			nodeList := &v1.NodeList{}
+			listErr := k8sClient.List(ctx, nodeList)
+			Expect(listErr).NotTo(HaveOccurred())
+
+			drainingNodes := 0
+			for _, node := range nodeList.Items {
+				if utils.NodeHasAnnotation(node, "sriovnetwork.openshift.io/state", "Draining") {
+					drainingNodes++
+				}
+			}
+			Expect(drainingNodes).To(Equal(1))
+		})
+
+		It("Should drain two nodes", func() {
+			config := &sriovnetworkv1.SriovOperatorConfig{}
+			err := util.WaitForNamespacedObject(config, k8sClient, testNamespace, "default", interval, timeout)
+			Expect(err).NotTo(HaveOccurred())
+			config.Spec = sriovnetworkv1.SriovOperatorConfigSpec{
+				EnableInjector:               func() *bool { b := true; return &b }(),
+				EnableOperatorWebhook:        func() *bool { b := true; return &b }(),
+				MaxParallelNodeConfiguration: 2,
+			}
+			updateErr := k8sClient.Update(goctx.TODO(), config)
+			Expect(updateErr).NotTo(HaveOccurred())
+			time.Sleep(3 * time.Second)
+
+			nodeList := &v1.NodeList{}
+			listErr := k8sClient.List(ctx, nodeList)
+			Expect(listErr).NotTo(HaveOccurred())
+
+			for _, node := range nodeList.Items {
+				Expect(utils.NodeHasAnnotation(node, "sriovnetwork.openshift.io/state", "Draining")).To(BeTrue())
+			}
+		})
+
+		It("Should drain all nodes", func() {
+			config := &sriovnetworkv1.SriovOperatorConfig{}
+			err := util.WaitForNamespacedObject(config, k8sClient, testNamespace, "default", interval, timeout)
+			Expect(err).NotTo(HaveOccurred())
+			config.Spec = sriovnetworkv1.SriovOperatorConfigSpec{
+				EnableInjector:               func() *bool { b := true; return &b }(),
+				EnableOperatorWebhook:        func() *bool { b := true; return &b }(),
+				MaxParallelNodeConfiguration: 0,
+			}
+			updateErr := k8sClient.Update(goctx.TODO(), config)
+			Expect(updateErr).NotTo(HaveOccurred())
+			time.Sleep(3 * time.Second)
+
+			nodeList := &v1.NodeList{}
+			listErr := k8sClient.List(ctx, nodeList)
+			Expect(listErr).NotTo(HaveOccurred())
+
+			for _, node := range nodeList.Items {
+				Expect(utils.NodeHasAnnotation(node, "sriovnetwork.openshift.io/state", "Draining")).To(BeTrue())
+			}
+		})
+	})
+})
diff --git a/controllers/suite_test.go b/controllers/suite_test.go
index 9bd666ae7..7ac67f913 100644
--- a/controllers/suite_test.go
+++ b/controllers/suite_test.go
@@ -127,6 +127,12 @@ var _ = BeforeSuite(func(done Done) {
 	}).SetupWithManager(k8sManager)
 	Expect(err).ToNot(HaveOccurred())
 
+	err = (&DrainReconciler{
+		Client: k8sManager.GetClient(),
+		Scheme: k8sManager.GetScheme(),
+	}).SetupWithManager(k8sManager)
+	Expect(err).ToNot(HaveOccurred())
+
 	os.Setenv("RESOURCE_PREFIX", "openshift.io")
 	os.Setenv("NAMESPACE", "openshift-sriov-network-operator")
 	os.Setenv("ENABLE_ADMISSION_CONTROLLER", "true")
diff --git a/pkg/daemon/daemon.go b/pkg/daemon/daemon.go
index 4c653a17d..bfe25fd27 100644
--- a/pkg/daemon/daemon.go
+++ b/pkg/daemon/daemon.go
@@ -38,6 +38,7 @@ import (
 	sriovnetworkv1 "github.com/k8snetworkplumbingwg/sriov-network-operator/api/v1"
 	snclientset "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/client/clientset/versioned"
 	sninformer "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/client/informers/externalversions"
+	consts "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/consts"
 	plugin "github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/plugins"
 	"github.com/k8snetworkplumbingwg/sriov-network-operator/pkg/utils"
 )
@@ -562,7 +563,7 @@ func (dn *Daemon) nodeStateSyncHandler() error {
 // isNodeDraining: check if the node is draining
 // both Draining and MCP paused labels will return true
 func (dn *Daemon) isNodeDraining() bool {
-	anno, ok := dn.node.Annotations[annoKey]
+	anno, ok := dn.node.Annotations[consts.NodeDrainAnnotation]
 	if !ok {
 		return false
 	}
@@ -689,8 +690,8 @@ func (dn *Daemon) annotateNode(node, value string) error {
 	if newNode.Annotations == nil {
 		newNode.Annotations = map[string]string{}
 	}
-	if newNode.Annotations[annoKey] != value {
-		newNode.Annotations[annoKey] = value
+	if newNode.Annotations[consts.NodeDrainAnnotation] != value {
+		newNode.Annotations[consts.NodeDrainAnnotation] = value
 		newData, err := json.Marshal(newNode)
 		if err != nil {
 			return err