diff --git a/pkg/cache/application.go b/pkg/cache/application.go index 8438957ba..694e57137 100644 --- a/pkg/cache/application.go +++ b/pkg/cache/application.go @@ -48,7 +48,7 @@ type Application struct { groups []string taskMap map[string]*Task tags map[string]string - taskGroups []TaskGroup + taskGroups map[string]*TaskGroup taskGroupsDefinition string schedulingParamsDefinition string placeholderOwnerReferences []metav1.OwnerReference @@ -71,6 +71,7 @@ func (app *Application) String() string { func NewApplication(appID, queueName, user string, groups []string, tags map[string]string, scheduler api.SchedulerAPI) *Application { taskMap := make(map[string]*Task) + taskGroups := make(map[string]*TaskGroup) app := &Application{ applicationID: appID, queue: queueName, @@ -80,7 +81,7 @@ func NewApplication(appID, queueName, user string, groups []string, tags map[str taskMap: taskMap, tags: tags, sm: newAppState(), - taskGroups: make([]TaskGroup, 0), + taskGroups: taskGroups, lock: &locking.RWMutex{}, schedulerAPI: scheduler, placeholderTimeoutInSec: 0, @@ -166,9 +167,16 @@ func (app *Application) GetSchedulingParamsDefinition() string { func (app *Application) setTaskGroups(taskGroups []TaskGroup) { app.lock.Lock() defer app.lock.Unlock() - app.taskGroups = taskGroups - for _, taskGroup := range app.taskGroups { - app.placeholderAsk = common.Add(app.placeholderAsk, common.GetTGResource(taskGroup.MinResource, int64(taskGroup.MinMember))) + for _, tg := range taskGroups { + taskGroup := tg + if _, exists := app.taskGroups[taskGroup.Name]; exists { + log.Log(log.ShimCacheApplication).Warn("duplicate task-group within the task-groups", + zap.String("appID", app.applicationID), + zap.String("groupName", taskGroup.Name)) + } else { + app.taskGroups[taskGroup.Name] = &taskGroup + app.placeholderAsk = common.Add(app.placeholderAsk, common.GetTGResource(taskGroup.MinResource, int64(taskGroup.MinMember))) + } } } @@ -181,7 +189,14 @@ func (app *Application) getPlaceholderAsk() *si.Resource { func (app *Application) getTaskGroups() []TaskGroup { app.lock.RLock() defer app.lock.RUnlock() - return app.taskGroups + + taskGroups := make([]TaskGroup, 0) + if len(app.taskGroups) > 0 { + for _, taskGroup := range app.taskGroups { + taskGroups = append(taskGroups, *taskGroup) + } + } + return taskGroups } func (app *Application) setPlaceholderOwnerReferences(ref []metav1.OwnerReference) { diff --git a/pkg/cache/placeholder_test.go b/pkg/cache/placeholder_test.go index cbabd2891..f02267d20 100644 --- a/pkg/cache/placeholder_test.go +++ b/pkg/cache/placeholder_test.go @@ -119,9 +119,10 @@ func TestNewPlaceholder(t *testing.T) { assert.Equal(t, app.placeholderAsk.Resources[siCommon.Memory].Value, int64(10*1024*1000*1000)) assert.Equal(t, app.placeholderAsk.Resources["pods"].Value, int64(10)) - holder := newPlaceholder("ph-name", app, app.taskGroups[0]) + tgs := app.getTaskGroups() + holder := newPlaceholder("ph-name", app, tgs[0]) assert.Equal(t, holder.appID, appID) - assert.Equal(t, holder.taskGroupName, app.taskGroups[0].Name) + assert.Equal(t, holder.taskGroupName, tgs[0].Name) assert.Equal(t, holder.pod.Spec.SchedulerName, constants.SchedulerName) assert.Equal(t, holder.pod.Name, "ph-name") assert.Equal(t, holder.pod.Namespace, namespace) @@ -132,7 +133,7 @@ func TestNewPlaceholder(t *testing.T) { "labelKey1": "labelKeyValue1", }) assert.Equal(t, len(holder.pod.Annotations), 7, "unexpected number of annotations") - assert.Equal(t, holder.pod.Annotations[constants.AnnotationTaskGroupName], app.taskGroups[0].Name) + assert.Equal(t, holder.pod.Annotations[constants.AnnotationTaskGroupName], tgs[0].Name) assert.Equal(t, holder.pod.Annotations[constants.AnnotationPlaceholderFlag], constants.True) assert.Equal(t, holder.pod.Annotations["annotationKey0"], "annotationValue0") assert.Equal(t, holder.pod.Annotations["annotationKey1"], "annotationValue1") @@ -163,7 +164,8 @@ func TestNewPlaceholderWithNodeSelectors(t *testing.T) { "bob", testGroups, map[string]string{constants.AppTagNamespace: namespace}, mockedSchedulerAPI) app.setTaskGroups(taskGroups) - holder := newPlaceholder("ph-name", app, app.taskGroups[0]) + tgs := app.getTaskGroups() + holder := newPlaceholder("ph-name", app, tgs[0]) assert.Equal(t, len(holder.pod.Spec.NodeSelector), 2) assert.Equal(t, holder.pod.Spec.NodeSelector["nodeType"], "test") assert.Equal(t, holder.pod.Spec.NodeSelector["nodeState"], "healthy") @@ -178,7 +180,8 @@ func TestNewPlaceholderWithTolerations(t *testing.T) { "bob", testGroups, map[string]string{constants.AppTagNamespace: namespace}, mockedSchedulerAPI) app.setTaskGroups(taskGroups) - holder := newPlaceholder("ph-name", app, app.taskGroups[0]) + tgs := app.getTaskGroups() + holder := newPlaceholder("ph-name", app, tgs[0]) assert.Equal(t, len(holder.pod.Spec.Tolerations), 1) tlr := holder.pod.Spec.Tolerations[0] assert.Equal(t, tlr.Key, "key1") @@ -196,7 +199,8 @@ func TestNewPlaceholderWithAffinity(t *testing.T) { "bob", testGroups, map[string]string{constants.AppTagNamespace: namespace}, mockedSchedulerAPI) app.setTaskGroups(taskGroups) - holder := newPlaceholder("ph-name", app, app.taskGroups[0]) + tgs := app.getTaskGroups() + holder := newPlaceholder("ph-name", app, tgs[0]) assert.Equal(t, len(holder.pod.Spec.Affinity.PodAffinity.RequiredDuringSchedulingIgnoredDuringExecution), 1) term := holder.pod.Spec.Affinity.PodAffinity.RequiredDuringSchedulingIgnoredDuringExecution assert.Equal(t, term[0].TopologyKey, "topologyKey") @@ -215,14 +219,16 @@ func TestNewPlaceholderTaskGroupsDefinition(t *testing.T) { app := NewApplication(appID, queue, "bob", testGroups, map[string]string{constants.AppTagNamespace: namespace}, mockedSchedulerAPI) app.setTaskGroups(taskGroups) - holder := newPlaceholder("ph-name", app, app.taskGroups[0]) + tgs := app.getTaskGroups() + holder := newPlaceholder("ph-name", app, tgs[0]) assert.Equal(t, "", holder.pod.Annotations[constants.AnnotationTaskGroups]) app = NewApplication(appID, queue, "bob", testGroups, map[string]string{constants.AppTagNamespace: namespace}, mockedSchedulerAPI) app.setTaskGroups(taskGroups) app.setTaskGroupsDefinition("taskGroupsDef") - holder = newPlaceholder("ph-name", app, app.taskGroups[0]) + tgs = app.getTaskGroups() + holder = newPlaceholder("ph-name", app, tgs[0]) assert.Equal(t, "taskGroupsDef", holder.pod.Annotations[constants.AnnotationTaskGroups]) var priority *int32 assert.Equal(t, priority, holder.pod.Spec.Priority) @@ -234,7 +240,9 @@ func TestNewPlaceholderExtendedResources(t *testing.T) { app := NewApplication(appID, queue, "bob", testGroups, map[string]string{constants.AppTagNamespace: namespace}, mockedSchedulerAPI) app.setTaskGroups(taskGroups) - holder := newPlaceholder("ph-name", app, app.taskGroups[0]) + + tgs := app.getTaskGroups() + holder := newPlaceholder("ph-name", app, tgs[0]) assert.Equal(t, len(holder.pod.Spec.Containers[0].Resources.Requests), 5, "expected requests not found") assert.Equal(t, len(holder.pod.Spec.Containers[0].Resources.Limits), 5, "expected limits not found") assert.Equal(t, holder.pod.Spec.Containers[0].Resources.Limits[gpu], holder.pod.Spec.Containers[0].Resources.Requests[gpu], "gpu: expected same value for request and limit") @@ -271,7 +279,8 @@ func TestNewPlaceholderWithPriorityClassName(t *testing.T) { app.taskMap[taskID1] = task1 app.setOriginatingTask(task1) - holder := newPlaceholder("ph-name", app, app.taskGroups[0]) + tgs := app.getTaskGroups() + holder := newPlaceholder("ph-name", app, tgs[0]) assert.Equal(t, len(holder.pod.Spec.Containers[0].Resources.Requests), 5, "expected requests not found") assert.Equal(t, len(holder.pod.Spec.Containers[0].Resources.Limits), 5, "expected limits not found") assert.Equal(t, holder.pod.Spec.Containers[0].Resources.Limits[gpu], holder.pod.Spec.Containers[0].Resources.Requests[gpu], "gpu: expected same value for request and limit") @@ -287,7 +296,8 @@ func TestNewPlaceholderWithTopologySpreadConstraints(t *testing.T) { "bob", testGroups, map[string]string{constants.AppTagNamespace: namespace}, mockedSchedulerAPI) app.setTaskGroups(taskGroups) - holder := newPlaceholder("ph-name", app, app.taskGroups[0]) + tgs := app.getTaskGroups() + holder := newPlaceholder("ph-name", app, tgs[0]) assert.Equal(t, len(holder.pod.Spec.TopologySpreadConstraints), 1) assert.Equal(t, holder.pod.Spec.TopologySpreadConstraints[0].MaxSkew, int32(1)) assert.Equal(t, holder.pod.Spec.TopologySpreadConstraints[0].TopologyKey, v1.LabelTopologyZone)