diff --git a/pkg/controller.v1/tensorflow/tfjob_controller_test.go b/pkg/controller.v1/tensorflow/tfjob_controller_test.go
index 4236ccb0c2..1d6f1490e5 100644
--- a/pkg/controller.v1/tensorflow/tfjob_controller_test.go
+++ b/pkg/controller.v1/tensorflow/tfjob_controller_test.go
@@ -21,13 +21,17 @@ import (
 	. "github.com/onsi/ginkgo/v2"
 	. "github.com/onsi/gomega"
 	corev1 "k8s.io/api/core/v1"
+	"k8s.io/apimachinery/pkg/api/errors"
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/types"
 	"k8s.io/apimachinery/pkg/util/uuid"
+	"k8s.io/utils/pointer"
 	"sigs.k8s.io/controller-runtime/pkg/client"
 
 	kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
 	tftestutil "github.com/kubeflow/training-operator/pkg/controller.v1/tensorflow/testutil"
 	commonutil "github.com/kubeflow/training-operator/pkg/util"
+	"github.com/kubeflow/training-operator/pkg/util/testutil"
 )
 
 var _ = Describe("TFJob controller", func() {
@@ -319,4 +323,288 @@ var _ = Describe("TFJob controller", func() {
 			}
 		})
 	})
+
+	Context("TFJob with suspend semantics", func() {
+		const name = "test-job"
+		var (
+			ns         *corev1.Namespace
+			job        *kubeflowv1.TFJob
+			jobKey     types.NamespacedName
+			chiefKey   types.NamespacedName
+			worker0Key types.NamespacedName
+			ctx        = context.Background()
+		)
+		BeforeEach(func() {
+			ns = &corev1.Namespace{
+				ObjectMeta: metav1.ObjectMeta{
+					GenerateName: "tensorflow-test-",
+				},
+			}
+			Expect(testK8sClient.Create(ctx, ns)).Should(Succeed())
+
+			// chief=1, worker=1
+			job = tftestutil.NewTFJobV2(1, 0, 0, 1, 0)
+			job.SetName(name)
+			job.SetNamespace(ns.Name)
+			jobKey = client.ObjectKeyFromObject(job)
+			chiefKey = types.NamespacedName{
+				Name:      fmt.Sprintf("%s-chief-0", name),
+				Namespace: ns.Name,
+			}
+			worker0Key = types.NamespacedName{
+				Name:      fmt.Sprintf("%s-worker-0", name),
+				Namespace: ns.Name,
+			}
+		})
+		AfterEach(func() {
+			Expect(testK8sClient.Delete(ctx, job)).Should(Succeed())
+			Expect(testK8sClient.Delete(ctx, ns)).Should(Succeed())
+		})
+
+		It("Shouldn't create resources if TFJob is suspended", func() {
+			By("By creating a new TFJob with suspend=true")
+			job.Spec.RunPolicy.Suspend = pointer.Bool(true)
+			Expect(testK8sClient.Create(ctx, job)).Should(Succeed())
+
+			created := &kubeflowv1.TFJob{}
+			chiefPod := &corev1.Pod{}
+			workerPod := &corev1.Pod{}
+			chiefSvc := &corev1.Service{}
+			workerSvc := &corev1.Service{}
+
+			By("Checking created TFJob")
+			Eventually(func() bool {
+				err := testK8sClient.Get(ctx, jobKey, created)
+				return err == nil
+			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
+			By("Checking created TFJob has a nil startTime")
+			Consistently(func() *metav1.Time {
+				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
+				return created.Status.StartTime
+			}, testutil.ConsistentDuration, testutil.Interval).Should(BeNil())
+
+			By("Checking if the pods and services aren't created")
+			Consistently(func() bool {
+				errChiefPod := testK8sClient.Get(ctx, chiefKey, chiefPod)
+				errWorkerPod := testK8sClient.Get(ctx, worker0Key, workerPod)
+				errChiefSvc := testK8sClient.Get(ctx, chiefKey, chiefSvc)
+				errWorkerSvc := testK8sClient.Get(ctx, worker0Key, workerSvc)
+				return errors.IsNotFound(errChiefPod) && errors.IsNotFound(errWorkerPod) &&
+					errors.IsNotFound(errChiefSvc) && errors.IsNotFound(errWorkerSvc)
+			}, testutil.ConsistentDuration, testutil.Interval).Should(BeTrue())
+
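+			// Even though no pods ever ran, the suspended job should still be
+			// marked Created and should carry a Suspended condition.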
+			By("Checking if the TFJob has suspended condition")
+			Eventually(func() []kubeflowv1.JobCondition {
+				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
+				return created.Status.Conditions
+			}, testutil.ConsistentDuration, testutil.Interval).Should(BeComparableTo([]kubeflowv1.JobCondition{
+				{
+					Type:    kubeflowv1.JobCreated,
+					Status:  corev1.ConditionTrue,
+					Reason:  commonutil.NewReason(kubeflowv1.TFJobKind, commonutil.JobCreatedReason),
+					Message: fmt.Sprintf("TFJob %s is created.", name),
+				},
+				{
+					Type:    kubeflowv1.JobSuspended,
+					Status:  corev1.ConditionTrue,
+					Reason:  commonutil.NewReason(kubeflowv1.TFJobKind, commonutil.JobSuspendedReason),
+					Message: fmt.Sprintf("TFJob %s is suspended.", name),
+				},
+			}, testutil.IgnoreJobConditionsTimes))
+		})
+
+		It("Should delete resources after TFJob is suspended; Should resume TFJob after TFJob is unsuspended", func() {
+			By("By creating a new TFJob")
+			Expect(testK8sClient.Create(ctx, job)).Should(Succeed())
+
+			created := &kubeflowv1.TFJob{}
+			chiefPod := &corev1.Pod{}
+			workerPod := &corev1.Pod{}
+			chiefSvc := &corev1.Service{}
+			workerSvc := &corev1.Service{}
+
+			// We'll need to retry getting this newly created TFJob, given that creation may not immediately happen.
+			By("Checking created TFJob")
+			Eventually(func() bool {
+				err := testK8sClient.Get(ctx, jobKey, created)
+				return err == nil
+			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
+
+			var startTimeBeforeSuspended *metav1.Time
+			Eventually(func() *metav1.Time {
+				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
+				startTimeBeforeSuspended = created.Status.StartTime
+				return startTimeBeforeSuspended
+			}, testutil.Timeout, testutil.Interval).ShouldNot(BeNil())
+
+			By("Checking the created pods and services")
+			Eventually(func() bool {
+				errChief := testK8sClient.Get(ctx, chiefKey, chiefPod)
+				errWorker := testK8sClient.Get(ctx, worker0Key, workerPod)
+				return errChief == nil && errWorker == nil
+			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
+			Eventually(func() bool {
+				errChief := testK8sClient.Get(ctx, chiefKey, chiefSvc)
+				errWorker := testK8sClient.Get(ctx, worker0Key, workerSvc)
+				return errChief == nil && errWorker == nil
+			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
+
+			By("Updating the pod's phase with Running")
+			Eventually(func() error {
+				Expect(testK8sClient.Get(ctx, chiefKey, chiefPod)).Should(Succeed())
+				chiefPod.Status.Phase = corev1.PodRunning
+				return testK8sClient.Status().Update(ctx, chiefPod)
+			}, testutil.Timeout, testutil.Interval).Should(Succeed())
+			Eventually(func() error {
+				Expect(testK8sClient.Get(ctx, worker0Key, workerPod)).Should(Succeed())
+				workerPod.Status.Phase = corev1.PodRunning
+				return testK8sClient.Status().Update(ctx, workerPod)
+			}, testutil.Timeout, testutil.Interval).Should(Succeed())
+
+			By("Checking the TFJob's condition")
+			Eventually(func() []kubeflowv1.JobCondition {
+				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
+				return created.Status.Conditions
+			}, testutil.Timeout, testutil.Interval).Should(BeComparableTo([]kubeflowv1.JobCondition{
+				{
+					Type:    kubeflowv1.JobCreated,
+					Status:  corev1.ConditionTrue,
+					Reason:  commonutil.NewReason(kubeflowv1.TFJobKind, commonutil.JobCreatedReason),
+					Message: fmt.Sprintf("TFJob %s is created.", name),
+				},
+				{
+					Type:    kubeflowv1.JobRunning,
+					Status:  corev1.ConditionTrue,
+					Reason:  commonutil.NewReason(kubeflowv1.TFJobKind, commonutil.JobRunningReason),
+					Message: fmt.Sprintf("TFJob %s/%s is running.", ns.Name, name),
+				},
+			}, testutil.IgnoreJobConditionsTimes))
+
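+			// Suspending the running job should prompt the controller to delete
+			// the pods and services while the TFJob object itself is kept.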
+			By("Updating the TFJob with suspend=true")
+			Eventually(func() error {
+				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
+				created.Spec.RunPolicy.Suspend = pointer.Bool(true)
+				return testK8sClient.Update(ctx, created)
+			}, testutil.Timeout, testutil.Interval).Should(Succeed())
+
+			By("Checking if the pods and services are removed")
+			Eventually(func() bool {
+				errChief := testK8sClient.Get(ctx, chiefKey, chiefPod)
+				errWorker := testK8sClient.Get(ctx, worker0Key, workerPod)
+				return errors.IsNotFound(errChief) && errors.IsNotFound(errWorker)
+			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
+			Eventually(func() bool {
+				errChief := testK8sClient.Get(ctx, chiefKey, chiefSvc)
+				errWorker := testK8sClient.Get(ctx, worker0Key, workerSvc)
+				return errors.IsNotFound(errChief) && errors.IsNotFound(errWorker)
+			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
+			Consistently(func() bool {
+				errChiefPod := testK8sClient.Get(ctx, chiefKey, chiefPod)
+				errWorkerPod := testK8sClient.Get(ctx, worker0Key, workerPod)
+				errChiefSvc := testK8sClient.Get(ctx, chiefKey, chiefSvc)
+				errWorkerSvc := testK8sClient.Get(ctx, worker0Key, workerSvc)
+				return errors.IsNotFound(errChiefPod) && errors.IsNotFound(errWorkerPod) &&
+					errors.IsNotFound(errChiefSvc) && errors.IsNotFound(errWorkerSvc)
+			}, testutil.ConsistentDuration, testutil.Interval).Should(BeTrue())
+
+			By("Checking if the TFJob has a suspended condition")
+			Eventually(func() bool {
+				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
+				return created.Status.ReplicaStatuses[kubeflowv1.TFJobReplicaTypeChief].Active == 0 &&
+					created.Status.ReplicaStatuses[kubeflowv1.TFJobReplicaTypeWorker].Active == 0 &&
+					created.Status.StartTime.Equal(startTimeBeforeSuspended)
+			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
+			Consistently(func() bool {
+				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
+				return created.Status.ReplicaStatuses[kubeflowv1.TFJobReplicaTypeChief].Active == 0 &&
+					created.Status.ReplicaStatuses[kubeflowv1.TFJobReplicaTypeWorker].Active == 0 &&
+					created.Status.StartTime.Equal(startTimeBeforeSuspended)
+			}, testutil.ConsistentDuration, testutil.Interval).Should(BeTrue())
+			Expect(created.Status.Conditions).Should(BeComparableTo([]kubeflowv1.JobCondition{
+				{
+					Type:    kubeflowv1.JobCreated,
+					Status:  corev1.ConditionTrue,
+					Reason:  commonutil.NewReason(kubeflowv1.TFJobKind, commonutil.JobCreatedReason),
+					Message: fmt.Sprintf("TFJob %s is created.", name),
+				},
+				{
+					Type:    kubeflowv1.JobRunning,
+					Status:  corev1.ConditionFalse,
+					Reason:  commonutil.NewReason(kubeflowv1.TFJobKind, commonutil.JobSuspendedReason),
+					Message: fmt.Sprintf("TFJob %s is suspended.", name),
+				},
+				{
+					Type:    kubeflowv1.JobSuspended,
+					Reason:  commonutil.NewReason(kubeflowv1.TFJobKind, commonutil.JobSuspendedReason),
+					Message: fmt.Sprintf("TFJob %s is suspended.", name),
+					Status:  corev1.ConditionTrue,
+				},
+			}, testutil.IgnoreJobConditionsTimes))
+
+			By("Unsuspending the TFJob")
+			Eventually(func() error {
+				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
+				created.Spec.RunPolicy.Suspend = pointer.Bool(false)
+				return testK8sClient.Update(ctx, created)
+			}, testutil.Timeout, testutil.Interval).Should(Succeed())
+			Eventually(func() *metav1.Time {
+				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
+				return created.Status.StartTime
+			}, testutil.Timeout, testutil.Interval).ShouldNot(BeNil())
+
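+			// After unsuspending, the controller recreates the pods and services,
+			// so each Get is retried until the objects exist again.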
+			By("Check if the pods and services are created")
+			Eventually(func() error {
+				return testK8sClient.Get(ctx, chiefKey, chiefPod)
+			}, testutil.Timeout, testutil.Interval).Should(BeNil())
+			Eventually(func() error {
+				return testK8sClient.Get(ctx, worker0Key, workerPod)
+			}, testutil.Timeout, testutil.Interval).Should(BeNil())
+			Eventually(func() error {
+				return testK8sClient.Get(ctx, chiefKey, chiefSvc)
+			}, testutil.Timeout, testutil.Interval).Should(BeNil())
+			Eventually(func() error {
+				return testK8sClient.Get(ctx, worker0Key, workerSvc)
+			}, testutil.Timeout, testutil.Interval).Should(BeNil())
+
+			By("Updating Pod's condition with running")
+			Eventually(func() error {
+				Expect(testK8sClient.Get(ctx, chiefKey, chiefPod)).Should(Succeed())
+				chiefPod.Status.Phase = corev1.PodRunning
+				return testK8sClient.Status().Update(ctx, chiefPod)
+			}, testutil.Timeout, testutil.Interval).Should(Succeed())
+			Eventually(func() error {
+				Expect(testK8sClient.Get(ctx, worker0Key, workerPod)).Should(Succeed())
+				workerPod.Status.Phase = corev1.PodRunning
+				return testK8sClient.Status().Update(ctx, workerPod)
+			}, testutil.Timeout, testutil.Interval).Should(Succeed())
+
+			By("Checking if the TFJob has resumed conditions")
+			Eventually(func() []kubeflowv1.JobCondition {
+				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
+				return created.Status.Conditions
+			}, testutil.Timeout, testutil.Interval).Should(BeComparableTo([]kubeflowv1.JobCondition{
+				{
+					Type:    kubeflowv1.JobCreated,
+					Status:  corev1.ConditionTrue,
+					Reason:  commonutil.NewReason(kubeflowv1.TFJobKind, commonutil.JobCreatedReason),
+					Message: fmt.Sprintf("TFJob %s is created.", name),
+				},
+				{
+					Type:    kubeflowv1.JobSuspended,
+					Reason:  commonutil.NewReason(kubeflowv1.TFJobKind, commonutil.JobResumedReason),
+					Message: fmt.Sprintf("TFJob %s is resumed.", name),
+					Status:  corev1.ConditionFalse,
+				},
+				{
+					Type:    kubeflowv1.JobRunning,
+					Status:  corev1.ConditionTrue,
+					Reason:  commonutil.NewReason(kubeflowv1.TFJobKind, commonutil.JobRunningReason),
+					Message: fmt.Sprintf("TFJob %s/%s is running.", ns.Name, name),
+				},
+			}, testutil.IgnoreJobConditionsTimes))
+
+			By("Checking if the startTime is updated")
+			Expect(created.Status.StartTime).ShouldNot(Equal(startTimeBeforeSuspended))
+		})
+	})
 })