Fix unit tests in runtime package
Signed-off-by: Andrey Velichkevich <[email protected]>
andreyvelich committed Oct 25, 2024
1 parent 37d5de0 commit 86e606a
Showing 7 changed files with 107 additions and 151 deletions.
9 changes: 3 additions & 6 deletions pkg/runtime.v2/core/clustertrainingruntime_test.go
@@ -22,11 +22,9 @@ import (

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
-batchv1 "k8s.io/api/batch/v1"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/utils/ptr"
"sigs.k8s.io/controller-runtime/pkg/client"
schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1"

@@ -72,18 +70,17 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) {
testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job").
Suspend(true).
PodLabel(schedulerpluginsv1alpha1.PodGroupLabel, "test-job").
-ContainerImage(ptr.To("test:trainjob")).
-JobCompletionMode(batchv1.IndexedCompletion).
+ContainerImage("test:trainjob").
ResourceRequests(0, corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("1"),
}).
ResourceRequests(1, corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("2"),
}).
-ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid").
+ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainJobKind), "test-job", "uid").
Obj(),
testingutil.MakeSchedulerPluginsPodGroup(metav1.NamespaceDefault, "test-job").
-ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid").
+ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainJobKind), "test-job", "uid").
MinMember(40).
SchedulingTimeout(120).
MinResources(corev1.ResourceList{
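The ControllerReference changes in this file swap the hard-coded "TrainJob" string for the exported kubeflowv2.TrainJobKind constant, so the expected owner reference can no longer drift from the registered kind name. A minimal sketch of the pattern, using stand-in definitions (the real ones live in pkg/apis/kubeflow.org/v2alpha1):

```go
package main

import (
	"fmt"

	"k8s.io/apimachinery/pkg/runtime/schema"
)

// Stand-ins for kubeflowv2.SchemeGroupVersion and kubeflowv2.TrainJobKind.
var SchemeGroupVersion = schema.GroupVersion{Group: "kubeflow.org", Version: "v2alpha1"}

const TrainJobKind = "TrainJob"

func main() {
	// Deriving the GVK from the exported constant keeps owner references
	// consistent with the API scheme instead of repeating a string literal.
	gvk := SchemeGroupVersion.WithKind(TrainJobKind)
	fmt.Println(gvk) // prints: kubeflow.org/v2alpha1, Kind=TrainJob
}
```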
1 change: 1 addition & 0 deletions pkg/runtime.v2/core/trainingruntime.go
@@ -104,6 +104,7 @@ func (r *TrainingRuntime) buildObjects(
runtime.WithPodGroupPolicy(podGroupPolicy),
}
for _, rJob := range jobSetTemplateSpec.Spec.ReplicatedJobs {
+// By default every ReplicatedJob has only 1 replica.
opts = append(opts, runtime.WithPodSpecReplicas(rJob.Name, 1, rJob.Template.Spec.Template.Spec))
}

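The added comment documents why the count is hard-coded: a JobSet ReplicatedJob defaults to one replica, so buildObjects registers each replicated job with replicas = 1. For context, a minimal sketch of what a functional option like runtime.WithPodSpecReplicas could look like — the InfoOption name and the Info internals here are assumptions for illustration, not the actual pkg/runtime.v2 code:

```go
package runtime

import corev1 "k8s.io/api/core/v1"

// Trimmed stand-ins for the real pkg/runtime.v2 types.
type TotalResourceRequest struct {
	Replicas    int32
	PodRequests corev1.ResourceList
}

type Info struct {
	TotalRequests map[string]TotalResourceRequest
}

// InfoOption mutates an Info while it is being assembled.
type InfoOption func(*Info)

// WithPodSpecReplicas records the replica count and the pod's aggregate
// container requests under the replicated job's name.
func WithPodSpecReplicas(name string, replicas int32, podSpec corev1.PodSpec) InfoOption {
	return func(info *Info) {
		requests := corev1.ResourceList{}
		for _, c := range podSpec.Containers {
			for k, v := range c.Resources.Requests {
				if cur, ok := requests[k]; ok {
					cur.Add(v)
					requests[k] = cur
				} else {
					requests[k] = v.DeepCopy()
				}
			}
		}
		info.TotalRequests[name] = TotalResourceRequest{Replicas: replicas, PodRequests: requests}
	}
}
```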
9 changes: 3 additions & 6 deletions pkg/runtime.v2/core/trainingruntime_test.go
@@ -22,11 +22,9 @@ import (

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
-batchv1 "k8s.io/api/batch/v1"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/utils/ptr"
"sigs.k8s.io/controller-runtime/pkg/client"
schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1"

@@ -79,18 +77,17 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
Label("conflictLabel", "override").
Annotation("conflictAnnotation", "override").
PodLabel(schedulerpluginsv1alpha1.PodGroupLabel, "test-job").
-ContainerImage(ptr.To("test:trainjob")).
-JobCompletionMode(batchv1.IndexedCompletion).
+ContainerImage("test:trainjob").
ResourceRequests(0, corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("1"),
}).
ResourceRequests(1, corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("2"),
}).
-ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid").
+ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainJobKind), "test-job", "uid").
Obj(),
testingutil.MakeSchedulerPluginsPodGroup(metav1.NamespaceDefault, "test-job").
-ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid").
+ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainJobKind), "test-job", "uid").
MinMember(40).
SchedulingTimeout(120).
MinResources(corev1.ResourceList{
103 changes: 70 additions & 33 deletions pkg/runtime.v2/framework/core/framework_test.go
@@ -22,7 +22,6 @@ import (

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
-batchv1 "k8s.io/api/batch/v1"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/meta"
"k8s.io/apimachinery/pkg/api/resource"
@@ -35,6 +34,7 @@ import (
schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1"

kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1"
"github.com/kubeflow/training-operator/pkg/constants"
runtime "github.com/kubeflow/training-operator/pkg/runtime.v2"
"github.com/kubeflow/training-operator/pkg/runtime.v2/framework"
fwkplugins "github.com/kubeflow/training-operator/pkg/runtime.v2/framework/plugins"
@@ -150,10 +150,11 @@ func TestRunEnforceMLPolicyPlugins(t *testing.T) {
cases := map[string]struct {
registry fwkplugins.Registry
runtimeInfo *runtime.Info
+trainJob *kubeflowv2.TrainJob
wantRuntimeInfo *runtime.Info
wantError error
}{
"plainml MLPolicy is applied to runtime.Info": {
"plainml MLPolicy is applied to runtime.Info, TrainJob doesn't have numNodes": {
registry: fwkplugins.NewRegistry(),
runtimeInfo: &runtime.Info{
Policy: runtime.Policy{
@@ -162,19 +163,60 @@ func TestRunEnforceMLPolicyPlugins(t *testing.T) {
},
},
TotalRequests: map[string]runtime.TotalResourceRequest{
"Coordinator": {Replicas: 1},
"Worker": {Replicas: 10},
constants.JobInitializer: {Replicas: 1},
constants.JobTrainerNode: {Replicas: 10},
},
},
+trainJob: &kubeflowv2.TrainJob{
+Spec: kubeflowv2.TrainJobSpec{},
+},
wantRuntimeInfo: &runtime.Info{
Policy: runtime.Policy{
MLPolicy: &kubeflowv2.MLPolicy{
NumNodes: ptr.To[int32](100),
},
},
+Trainer: runtime.Trainer{
+NumNodes: ptr.To[int32](100),
+},
TotalRequests: map[string]runtime.TotalResourceRequest{
"Coordinator": {Replicas: 100},
"Worker": {Replicas: 100},
constants.JobInitializer: {Replicas: 1},
constants.JobTrainerNode: {Replicas: 100},
},
},
},
"plainml MLPolicy is applied to runtime.Info, TrainJob has numNodes": {
registry: fwkplugins.NewRegistry(),
runtimeInfo: &runtime.Info{
Policy: runtime.Policy{
MLPolicy: &kubeflowv2.MLPolicy{
NumNodes: ptr.To[int32](100),
},
},
TotalRequests: map[string]runtime.TotalResourceRequest{
constants.JobInitializer: {Replicas: 1},
constants.JobTrainerNode: {Replicas: 10},
},
},
trainJob: &kubeflowv2.TrainJob{
Spec: kubeflowv2.TrainJobSpec{
Trainer: &kubeflowv2.Trainer{
NumNodes: ptr.To[int32](30),
},
},
},
wantRuntimeInfo: &runtime.Info{
Policy: runtime.Policy{
MLPolicy: &kubeflowv2.MLPolicy{
NumNodes: ptr.To[int32](100),
},
},
Trainer: runtime.Trainer{
NumNodes: ptr.To[int32](30),
},
TotalRequests: map[string]runtime.TotalResourceRequest{
constants.JobInitializer: {Replicas: 1},
constants.JobTrainerNode: {Replicas: 30},
},
},
},
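The map keys above also move from ad-hoc "Coordinator" and "Worker" strings to identifiers from the new pkg/constants import, so the tests name jobs the same way the runtime does. The string values below are illustrative assumptions only; the diff does not show the actual definitions:

```go
// Hypothetical sketch of the identifiers referenced above; the real
// declarations live in pkg/constants.
package constants

const (
	JobInitializer = "initializer"  // assumed value
	JobTrainerNode = "trainer-node" // assumed value
)
```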
@@ -186,8 +228,8 @@ func TestRunEnforceMLPolicyPlugins(t *testing.T) {
},
},
TotalRequests: map[string]runtime.TotalResourceRequest{
"Coordinator": {Replicas: 1},
"Worker": {Replicas: 10},
constants.JobInitializer: {Replicas: 1},
constants.JobTrainerNode: {Replicas: 10},
},
},
wantRuntimeInfo: &runtime.Info{
@@ -197,8 +239,8 @@ func TestRunEnforceMLPolicyPlugins(t *testing.T) {
},
},
TotalRequests: map[string]runtime.TotalResourceRequest{
"Coordinator": {Replicas: 1},
"Worker": {Replicas: 10},
constants.JobInitializer: {Replicas: 1},
constants.JobTrainerNode: {Replicas: 10},
},
},
},
@@ -213,7 +255,7 @@ func TestRunEnforceMLPolicyPlugins(t *testing.T) {
if err != nil {
t.Fatal(err)
}
-err = fwk.RunEnforceMLPolicyPlugins(tc.runtimeInfo)
+err = fwk.RunEnforceMLPolicyPlugins(tc.runtimeInfo, tc.trainJob)
if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 {
t.Errorf("Unexpected error (-want,+got): %s", diff)
}
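The two cases above pin down the precedence RunEnforceMLPolicyPlugins must honour now that it also receives the TrainJob: 100 nodes from the runtime's MLPolicy when the TrainJob is silent, 30 when the TrainJob sets its own count. A sketch of that resolution, assuming this is roughly what the plugin does internally (the real implementation lives in the framework plugins, not in this test file):

```go
package plugins

import (
	kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1"
	runtime "github.com/kubeflow/training-operator/pkg/runtime.v2"
)

// resolveNumNodes mirrors the behaviour the tests assert: a NumNodes set on
// the TrainJob's Trainer overrides the MLPolicy default from the runtime.
func resolveNumNodes(info *runtime.Info, trainJob *kubeflowv2.TrainJob) *int32 {
	if trainJob != nil && trainJob.Spec.Trainer != nil && trainJob.Spec.Trainer.NumNodes != nil {
		return trainJob.Spec.Trainer.NumNodes
	}
	return info.Policy.MLPolicy.NumNodes
}
```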
@@ -274,7 +316,7 @@ func TestRunEnforcePodGroupPolicyPlugins(t *testing.T) {
if err != nil {
t.Fatal(err)
}
-err = fwk.RunEnforcePodGroupPolicyPlugins(tc.trainJob, tc.runtimeInfo)
+err = fwk.RunEnforcePodGroupPolicyPlugins(tc.runtimeInfo, tc.trainJob)
if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 {
t.Errorf("Unexpected error (-want,+got): %s", diff)
}
@@ -337,17 +379,17 @@ func TestRunComponentBuilderPlugins(t *testing.T) {
})
jobSetWithPropagatedTrainJobParams := jobSetBase.
Clone().
-JobCompletionMode(batchv1.IndexedCompletion).
-ContainerImage(ptr.To("foo:bar")).
+ContainerImage("foo:bar").
ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid")

cases := map[string]struct {
-runtimeInfo *runtime.Info
-trainJob *kubeflowv2.TrainJob
-registry fwkplugins.Registry
-wantError error
-wantRuntimeInfo *runtime.Info
-wantObjs []client.Object
+runtimeInfo *runtime.Info
+trainJob *kubeflowv2.TrainJob
+runtimeJobTemplateSpec interface{}
+registry fwkplugins.Registry
+wantError error
+wantRuntimeInfo *runtime.Info
+wantObjs []client.Object
}{
"coscheduling and jobset are performed": {
trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").
@@ -359,9 +401,6 @@ func TestRunComponentBuilderPlugins(t *testing.T) {
).
Obj(),
runtimeInfo: &runtime.Info{
-Obj: jobSetBase.
-Clone().
-Obj(),
Policy: runtime.Policy{
MLPolicy: &kubeflowv2.MLPolicy{
NumNodes: ptr.To[int32](10),
@@ -375,14 +414,14 @@ func TestRunComponentBuilderPlugins(t *testing.T) {
},
},
TotalRequests: map[string]runtime.TotalResourceRequest{
"Coordinator": {
constants.JobInitializer: {
Replicas: 1,
PodRequests: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("2"),
corev1.ResourceMemory: resource.MustParse("4Gi"),
},
},
"Worker": {
constants.JobTrainerNode: {
Replicas: 1,
PodRequests: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("1"),
@@ -391,7 +430,8 @@ func TestRunComponentBuilderPlugins(t *testing.T) {
},
},
},
-registry: fwkplugins.NewRegistry(),
+runtimeJobTemplateSpec: jobSetBase.Spec,
+registry: fwkplugins.NewRegistry(),
wantObjs: []client.Object{
testingutil.MakeSchedulerPluginsPodGroup(metav1.NamespaceDefault, "test-job").
SchedulingTimeout(300).
@@ -407,9 +447,6 @@ func TestRunComponentBuilderPlugins(t *testing.T) {
Obj(),
},
wantRuntimeInfo: &runtime.Info{
-Obj: jobSetWithPropagatedTrainJobParams.
-Clone().
-Obj(),
Policy: runtime.Policy{
MLPolicy: &kubeflowv2.MLPolicy{
NumNodes: ptr.To[int32](10),
@@ -423,14 +460,14 @@ func TestRunComponentBuilderPlugins(t *testing.T) {
},
},
TotalRequests: map[string]runtime.TotalResourceRequest{
"Coordinator": {
constants.JobInitializer: {
Replicas: 10,
PodRequests: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("2"),
corev1.ResourceMemory: resource.MustParse("4Gi"),
},
},
"Worker": {
constants.JobTrainerNode: {
Replicas: 10,
PodRequests: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("1"),
@@ -458,10 +495,10 @@ func TestRunComponentBuilderPlugins(t *testing.T) {
if err != nil {
t.Fatal(err)
}
-if err = fwk.RunEnforceMLPolicyPlugins(tc.runtimeInfo); err != nil {
+if err = fwk.RunEnforceMLPolicyPlugins(tc.runtimeInfo, tc.trainJob); err != nil {
t.Fatal(err)
}
-objs, err := fwk.RunComponentBuilderPlugins(ctx, tc.runtimeInfo, tc.trainJob)
+objs, err := fwk.RunComponentBuilderPlugins(ctx, tc.runtimeInfo, tc.trainJob, tc.runtimeJobTemplateSpec)
if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 {
t.Errorf("Unexpected errors (-want,+got):\n%s", diff)
}
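Taken together, the updated call sites track three framework signature changes: RunEnforceMLPolicyPlugins and RunEnforcePodGroupPolicyPlugins now take (runtimeInfo, trainJob) in that order, and RunComponentBuilderPlugins additionally receives the runtime's job template spec, since Info no longer carries the JobSet object itself (the Obj field assignments above were removed). A sketch of that surface as an interface — the method names, argument order, and new parameters come from the calls in this diff, while the return types are inferred from how the tests use them:

```go
package framework

import (
	"context"

	"sigs.k8s.io/controller-runtime/pkg/client"

	kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1"
	runtime "github.com/kubeflow/training-operator/pkg/runtime.v2"
)

// Hooks sketches the framework methods exercised by these tests.
type Hooks interface {
	RunEnforceMLPolicyPlugins(info *runtime.Info, trainJob *kubeflowv2.TrainJob) error
	RunEnforcePodGroupPolicyPlugins(info *runtime.Info, trainJob *kubeflowv2.TrainJob) error
	RunComponentBuilderPlugins(ctx context.Context, info *runtime.Info, trainJob *kubeflowv2.TrainJob, runtimeJobTemplateSpec interface{}) ([]client.Object, error)
}
```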
1 change: 1 addition & 0 deletions pkg/runtime.v2/runtime.go
@@ -41,6 +41,7 @@ type Info struct {
}

type Trainer struct {
+// TODO (andreyvelich): Add more parameters.
NumNodes *int32
}
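The Trainer struct gives Info a place to hold per-TrainJob values resolved by the plugins — for now just NumNodes, with more fields to come per the TODO. A small usage sketch with stand-in types:

```go
package main

import (
	"fmt"

	"k8s.io/utils/ptr"
)

// Stand-ins mirroring the runtime.Info and runtime.Trainer shapes above.
type Trainer struct {
	NumNodes *int32
}

type Info struct {
	Trainer Trainer
}

func main() {
	// After RunEnforceMLPolicyPlugins runs, downstream builders can read
	// the resolved node count from Info instead of re-deriving it.
	info := Info{Trainer: Trainer{NumNodes: ptr.To[int32](30)}}
	if n := info.Trainer.NumNodes; n != nil {
		fmt.Printf("trainer runs on %d nodes\n", *n)
	}
}
```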
