diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go
index 1e9fe6115..6b0bc61a7 100644
--- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go
+++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go
@@ -57,12 +57,6 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx
 		return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "nil task specification")
 	}

-	pytorchTaskExtraArgs := plugins.DistributedPyTorchTrainingTask{}
-	err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &pytorchTaskExtraArgs)
-	if err != nil {
-		return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error())
-	}
-
 	podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx)
 	if err != nil {
 		return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error())
@@ -80,6 +74,7 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx
 		RestartPolicy: commonOp.RestartPolicyNever,
 	}
 	runPolicy := commonOp.RunPolicy{}
+	var elasticPolicy *kubeflowv1.ElasticPolicy

 	if taskTemplate.TaskTypeVersion == 0 {
 		pytorchTaskExtraArgs := plugins.DistributedPyTorchTrainingTask{}
@@ -90,6 +85,11 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx
 		}
 		workerReplica.ReplicaNum = pytorchTaskExtraArgs.GetWorkers()

+		// Set elastic config
+		elasticConfig := pytorchTaskExtraArgs.GetElasticConfig()
+		if elasticConfig != nil {
+			elasticPolicy = ParseElasticConfig(elasticConfig)
+		}
 	} else if taskTemplate.TaskTypeVersion == 1 {
 		kfPytorchTaskExtraArgs := kfplugins.DistributedPyTorchTrainingTask{}

@@ -134,6 +134,11 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx
 		if kfPytorchTaskExtraArgs.GetRunPolicy() != nil {
 			runPolicy = common.ParseRunPolicy(*kfPytorchTaskExtraArgs.GetRunPolicy())
 		}
+		// Set elastic config
+		elasticConfig := kfPytorchTaskExtraArgs.GetElasticConfig()
+		if elasticConfig != nil {
+			elasticPolicy = ParseElasticConfig(elasticConfig)
+		}
 	} else {
 		return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification,
 			"Invalid TaskSpecification, unsupported task template version [%v] key", taskTemplate.TaskTypeVersion)
@@ -164,23 +169,9 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx
 		RunPolicy: runPolicy,
 	}

-	// Set elastic config
-	elasticConfig := pytorchTaskExtraArgs.GetElasticConfig()
-	if elasticConfig != nil {
-		minReplicas := elasticConfig.GetMinReplicas()
-		maxReplicas := elasticConfig.GetMaxReplicas()
-		nProcPerNode := elasticConfig.GetNprocPerNode()
-		maxRestarts := elasticConfig.GetMaxRestarts()
-		rdzvBackend := kubeflowv1.RDZVBackend(elasticConfig.GetRdzvBackend())
-		elasticPolicy := kubeflowv1.ElasticPolicy{
-			MinReplicas:  &minReplicas,
-			MaxReplicas:  &maxReplicas,
-			RDZVBackend:  &rdzvBackend,
-			NProcPerNode: &nProcPerNode,
-			MaxRestarts:  &maxRestarts,
-		}
-		jobSpec.ElasticPolicy = &elasticPolicy
-		// Remove master replica if elastic policy is set
+	if elasticPolicy != nil {
+		jobSpec.ElasticPolicy = elasticPolicy
+		// Remove master replica spec if elastic policy is set
 		delete(jobSpec.PyTorchReplicaSpecs, kubeflowv1.PyTorchJobReplicaTypeMaster)
 	}

@@ -195,6 +186,32 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx
 	return job, nil
 }

+// Interface for unified elastic config handling across plugin versions v0 and v1. This interface should
+// always be aligned with the ElasticConfig defined in flyteidl.
+type ElasticConfig interface {
+	GetMinReplicas() int32
+	GetMaxReplicas() int32
+	GetNprocPerNode() int32
+	GetMaxRestarts() int32
+	GetRdzvBackend() string
+}
+
+// ParseElasticConfig parses the elastic config from both the v0 and v1 kubeflow pytorch idl.
+func ParseElasticConfig(elasticConfig ElasticConfig) *kubeflowv1.ElasticPolicy {
+	minReplicas := elasticConfig.GetMinReplicas()
+	maxReplicas := elasticConfig.GetMaxReplicas()
+	nProcPerNode := elasticConfig.GetNprocPerNode()
+	maxRestarts := elasticConfig.GetMaxRestarts()
+	rdzvBackend := kubeflowv1.RDZVBackend(elasticConfig.GetRdzvBackend())
+	return &kubeflowv1.ElasticPolicy{
+		MinReplicas:  &minReplicas,
+		MaxReplicas:  &maxReplicas,
+		RDZVBackend:  &rdzvBackend,
+		NProcPerNode: &nProcPerNode,
+		MaxRestarts:  &maxRestarts,
+	}
+}
+
 // Analyses the k8s resource and reports the status as TaskPhase. This call is expected to be relatively fast,
 // any operations that might take a long time (limits are configured system-wide) should be offloaded to the
 // background.
diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go
index 4a17b7490..27fa4c869 100644
--- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go
+++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go
@@ -519,10 +519,6 @@ func TestBuildResourcePytorchV1(t *testing.T) {
 				},
 			},
 		},
-		RunPolicy: &kfplugins.RunPolicy{
-			CleanPodPolicy: kfplugins.CleanPodPolicy_CLEANPOD_POLICY_ALL,
-			BackoffLimit:   100,
-		},
 	}

 	masterResourceRequirements := &corev1.ResourceRequirements{
@@ -567,14 +563,45 @@ func TestBuildResourcePytorchV1(t *testing.T) {
 	assert.Equal(t, commonOp.RestartPolicyAlways, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].RestartPolicy)
 	assert.Equal(t, commonOp.RestartPolicyNever, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].RestartPolicy)

-	assert.Equal(t, commonOp.CleanPodPolicyAll, *pytorchJob.Spec.RunPolicy.CleanPodPolicy)
-	assert.Equal(t, int32(100), *pytorchJob.Spec.RunPolicy.BackoffLimit)
+	assert.Nil(t, pytorchJob.Spec.RunPolicy.CleanPodPolicy)
+	assert.Nil(t, pytorchJob.Spec.RunPolicy.BackoffLimit)
 	assert.Nil(t, pytorchJob.Spec.RunPolicy.TTLSecondsAfterFinished)
 	assert.Nil(t, pytorchJob.Spec.RunPolicy.ActiveDeadlineSeconds)

 	assert.Nil(t, pytorchJob.Spec.ElasticPolicy)
 }

+func TestBuildResourcePytorchV1WithRunPolicy(t *testing.T) {
+	taskConfig := &kfplugins.DistributedPyTorchTrainingTask{
+		WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{
+			Replicas: 100,
+		},
+		RunPolicy: &kfplugins.RunPolicy{
+			CleanPodPolicy:          kfplugins.CleanPodPolicy_CLEANPOD_POLICY_ALL,
+			BackoffLimit:            100,
+			ActiveDeadlineSeconds:   1000,
+			TtlSecondsAfterFinished: 10000,
+		},
+	}
+	pytorchResourceHandler := pytorchOperatorResourceHandler{}
+
+	taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig)
+	taskTemplate.TaskTypeVersion = 1
+
+	res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate))
+	assert.NoError(t, err)
+	assert.NotNil(t, res)
+
+	pytorchJob, ok := res.(*kubeflowv1.PyTorchJob)
+	assert.True(t, ok)
+	assert.Equal(t, int32(100), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas)
+	assert.Nil(t, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Replicas)
+	assert.Equal(t, commonOp.CleanPodPolicyAll, *pytorchJob.Spec.RunPolicy.CleanPodPolicy)
+	assert.Equal(t, int32(100), *pytorchJob.Spec.RunPolicy.BackoffLimit)
+	assert.Equal(t, int64(1000), *pytorchJob.Spec.RunPolicy.ActiveDeadlineSeconds)
+	assert.Equal(t, int32(10000), *pytorchJob.Spec.RunPolicy.TTLSecondsAfterFinished)
+}
+
 func TestBuildResourcePytorchV1WithOnlyWorkerSpec(t *testing.T) {
 	taskConfig := &kfplugins.DistributedPyTorchTrainingTask{
 		WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{
@@ -638,3 +665,63 @@ func TestBuildResourcePytorchV1WithOnlyWorkerSpec(t *testing.T) {

 	assert.Nil(t, pytorchJob.Spec.ElasticPolicy)
 }
+
+func TestBuildResourcePytorchV1WithElastic(t *testing.T) {
+	taskConfig := &kfplugins.DistributedPyTorchTrainingTask{
+		WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{
+			Replicas: 2,
+		},
+		ElasticConfig: &kfplugins.ElasticConfig{MinReplicas: 1, MaxReplicas: 2, NprocPerNode: 4, RdzvBackend: "c10d"},
+	}
+	taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig)
+	taskTemplate.TaskTypeVersion = 1
+
+	pytorchResourceHandler := pytorchOperatorResourceHandler{}
+	resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate))
+	assert.NoError(t, err)
+	assert.NotNil(t, resource)
+
+	pytorchJob, ok := resource.(*kubeflowv1.PyTorchJob)
+	assert.True(t, ok)
+	assert.Equal(t, int32(2), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas)
+	assert.NotNil(t, pytorchJob.Spec.ElasticPolicy)
+	assert.Equal(t, int32(1), *pytorchJob.Spec.ElasticPolicy.MinReplicas)
+	assert.Equal(t, int32(2), *pytorchJob.Spec.ElasticPolicy.MaxReplicas)
+	assert.Equal(t, int32(4), *pytorchJob.Spec.ElasticPolicy.NProcPerNode)
+	assert.Equal(t, kubeflowv1.RDZVBackend("c10d"), *pytorchJob.Spec.ElasticPolicy.RDZVBackend)
+
+	assert.Equal(t, 1, len(pytorchJob.Spec.PyTorchReplicaSpecs))
+	assert.Contains(t, pytorchJob.Spec.PyTorchReplicaSpecs, kubeflowv1.PyTorchJobReplicaTypeWorker)
+
+	var hasContainerWithDefaultPytorchName = false
+
+	for _, container := range pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers {
+		if container.Name == kubeflowv1.PytorchJobDefaultContainerName {
+			hasContainerWithDefaultPytorchName = true
+		}
+	}
+
+	assert.True(t, hasContainerWithDefaultPytorchName)
+}
+
+func TestBuildResourcePytorchV1WithZeroWorker(t *testing.T) {
+	taskConfig := &kfplugins.DistributedPyTorchTrainingTask{
+		WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{
+			Replicas: 0,
+		},
+	}
+	pytorchResourceHandler := pytorchOperatorResourceHandler{}
+	taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig)
+	taskTemplate.TaskTypeVersion = 1
+	_, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate))
+	assert.Error(t, err)
+}
+
+func TestParseElasticConfig(t *testing.T) {
+	elasticConfig := plugins.ElasticConfig{MinReplicas: 1, MaxReplicas: 2, NprocPerNode: 4, RdzvBackend: "c10d"}
+	elasticPolicy := ParseElasticConfig(&elasticConfig)
+	assert.Equal(t, int32(1), *elasticPolicy.MinReplicas)
+	assert.Equal(t, int32(2), *elasticPolicy.MaxReplicas)
+	assert.Equal(t, int32(4), *elasticPolicy.NProcPerNode)
+	assert.Equal(t, kubeflowv1.RDZVBackend("c10d"), *elasticPolicy.RDZVBackend)
+}
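
Reviewer note (not part of the patch): a minimal sketch of the intent behind the new ElasticConfig interface, namely that the v0 (plugins.ElasticConfig) and v1 (kfplugins.ElasticConfig) generated types both satisfy it, so ParseElasticConfig gives BuildResource one code path for both task template versions. The helper name exampleElasticPolicies is hypothetical, and the import paths are assumptions based on the aliases used elsewhere in this plugin.

package pytorch

import (
	kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"

	"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
	kfplugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow"
)

// exampleElasticPolicies (hypothetical) shows that both IDL versions can be
// handed to ParseElasticConfig through the shared ElasticConfig interface.
func exampleElasticPolicies() []*kubeflowv1.ElasticPolicy {
	v0 := &plugins.ElasticConfig{MinReplicas: 1, MaxReplicas: 2, NprocPerNode: 4, RdzvBackend: "c10d"}
	v1 := &kfplugins.ElasticConfig{MinReplicas: 1, MaxReplicas: 2, NprocPerNode: 4, RdzvBackend: "c10d"}
	// Each call builds a *kubeflowv1.ElasticPolicy from the getters declared
	// on the ElasticConfig interface, regardless of which IDL produced it.
	return []*kubeflowv1.ElasticPolicy{ParseElasticConfig(v0), ParseElasticConfig(v1)}
}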