diff --git a/pkg/cache/context.go b/pkg/cache/context.go
index aaea690aa..b4dca41e7 100644
--- a/pkg/cache/context.go
+++ b/pkg/cache/context.go
@@ -72,7 +72,7 @@ type Context struct {
 	pluginMode   bool             // true if we are configured as a scheduler plugin
 	namespace    string           // yunikorn namespace
 	configMaps   []*v1.ConfigMap  // cached yunikorn configmaps
-	lock         *locking.RWMutex // lock
+	lock         *locking.RWMutex // lock - used not only for context data but also to ensure that multiple event types are not executed concurrently
 	txnID        atomic.Uint64    // transaction ID counter
 	klogger      klog.Logger
 }
@@ -166,6 +166,8 @@ func (ctx *Context) addNode(obj interface{}) {
 }
 
 func (ctx *Context) updateNode(_, obj interface{}) {
+	ctx.lock.Lock()
+	defer ctx.lock.Unlock()
 	node, err := convertToNode(obj)
 	if err != nil {
 		log.Log(log.ShimContext).Error("node conversion failed", zap.Error(err))
@@ -227,6 +229,8 @@ func (ctx *Context) updateNodeInternal(node *v1.Node, register bool) {
 }
 
 func (ctx *Context) deleteNode(obj interface{}) {
+	ctx.lock.Lock()
+	defer ctx.lock.Unlock()
 	var node *v1.Node
 	switch t := obj.(type) {
 	case *v1.Node:
@@ -246,6 +250,8 @@ func (ctx *Context) deleteNode(obj interface{}) {
 }
 
 func (ctx *Context) addNodesWithoutRegistering(nodes []*v1.Node) {
+	ctx.lock.Lock()
+	defer ctx.lock.Unlock()
 	for _, node := range nodes {
 		ctx.updateNodeInternal(node, false)
 	}
 }
@@ -281,6 +287,8 @@ func (ctx *Context) AddPod(obj interface{}) {
 }
 
 func (ctx *Context) UpdatePod(_, newObj interface{}) {
+	ctx.lock.Lock()
+	defer ctx.lock.Unlock()
 	pod, err := utils.Convert2Pod(newObj)
 	if err != nil {
 		log.Log(log.ShimContext).Error("failed to update pod", zap.Error(err))
@@ -328,7 +336,7 @@ func (ctx *Context) ensureAppAndTaskCreated(pod *v1.Pod, app *Application) {
 				zap.String("name", pod.Name))
 			return
 		}
-		app = ctx.AddApplication(&AddApplicationRequest{
+		app = ctx.addApplication(&AddApplicationRequest{
 			Metadata: appMeta,
 		})
 	}
@@ -432,8 +440,10 @@ func (ctx *Context) DeletePod(obj interface{}) {
 }
 
 func (ctx *Context) deleteYuniKornPod(pod *v1.Pod) {
+	ctx.lock.Lock()
+	defer ctx.lock.Unlock()
 	if taskMeta, ok := getTaskMetadata(pod); ok {
-		ctx.notifyTaskComplete(ctx.GetApplication(taskMeta.ApplicationID), taskMeta.TaskID)
+		ctx.notifyTaskComplete(ctx.getApplication(taskMeta.ApplicationID), taskMeta.TaskID)
 	}
 
 	log.Log(log.ShimContext).Debug("removing pod from cache", zap.String("podName", pod.Name))
@@ -441,6 +451,8 @@ func (ctx *Context) deleteYuniKornPod(pod *v1.Pod) {
 }
 
 func (ctx *Context) deleteForeignPod(pod *v1.Pod) {
+	ctx.lock.Lock()
+	defer ctx.lock.Unlock()
 	oldPod := ctx.schedulerCache.GetPod(string(pod.UID))
 	if oldPod == nil {
 		// if pod is not in scheduler cache, no node updates are needed
@@ -571,6 +583,8 @@ func (ctx *Context) addPriorityClass(obj interface{}) {
 }
 
 func (ctx *Context) updatePriorityClass(_, newObj interface{}) {
+	ctx.lock.Lock()
+	defer ctx.lock.Unlock()
 	if priorityClass := utils.Convert2PriorityClass(newObj); priorityClass != nil {
 		ctx.updatePriorityClassInternal(priorityClass)
 	}
 }
@@ -581,6 +595,8 @@ func (ctx *Context) updatePriorityClassInternal(priorityClass *schedulingv1.Prio
 }
 
 func (ctx *Context) deletePriorityClass(obj interface{}) {
+	ctx.lock.Lock()
+	defer ctx.lock.Unlock()
 	log.Log(log.ShimContext).Debug("priorityClass deleted")
 	var priorityClass *schedulingv1.PriorityClass
 	switch t := obj.(type) {
 	case *schedulingv1.PriorityClass:
@@ -646,6 +662,8 @@ func (ctx *Context) EventsToRegister(queueingHintFn framework.QueueingHintFn) []
 
 // IsPodFitNode evaluates given predicates based on current context
 func (ctx *Context) IsPodFitNode(name, node string, allocate bool) error {
+	ctx.lock.RLock()
+	defer ctx.lock.RUnlock()
 	pod := ctx.schedulerCache.GetPod(name)
 	if pod == nil {
 		return ErrorPodNotFound
@@ -666,6 +684,8 @@ func (ctx *Context) IsPodFitNode(name, node string, allocate bool) error {
 }
 
 func (ctx *Context) IsPodFitNodeViaPreemption(name, node string, allocations []string, startIndex int) (int, bool) {
+	ctx.lock.RLock()
+	defer ctx.lock.RUnlock()
 	if pod := ctx.schedulerCache.GetPod(name); pod != nil {
 		// if pod exists in cache, try to run predicates
 		if targetNode := ctx.schedulerCache.GetNode(node); targetNode != nil {
@@ -774,6 +794,8 @@ func (ctx *Context) bindPodVolumes(pod *v1.Pod) error {
 // this way, the core can make allocation decisions with consideration of
 // other assumed pods before they are actually bound to the node (bound is slow).
 func (ctx *Context) AssumePod(name, node string) error {
+	ctx.lock.Lock()
+	defer ctx.lock.Unlock()
 	if pod := ctx.schedulerCache.GetPod(name); pod != nil {
 		// when add assumed pod, we make a copy of the pod to avoid
 		// modifying its original reference. otherwise, it may have
@@ -833,6 +855,8 @@ func (ctx *Context) AssumePod(name, node string) error {
 // forget pod must be called when a pod is assumed to be running on a node,
 // but then for some reason it is failed to bind or released.
 func (ctx *Context) ForgetPod(name string) {
+	ctx.lock.Lock()
+	defer ctx.lock.Unlock()
 	if pod := ctx.schedulerCache.GetPod(name); pod != nil {
 		log.Log(log.ShimContext).Debug("forget pod", zap.String("pod", pod.Name))
 		ctx.schedulerCache.ForgetPod(pod)
@@ -949,6 +973,10 @@ func (ctx *Context) AddApplication(request *AddApplicationRequest) *Application
 	ctx.lock.Lock()
 	defer ctx.lock.Unlock()
 
+	return ctx.addApplication(request)
+}
+
+func (ctx *Context) addApplication(request *AddApplicationRequest) *Application {
 	log.Log(log.ShimContext).Debug("AddApplication", zap.Any("Request", request))
 	if app := ctx.getApplication(request.Metadata.ApplicationID); app != nil {
 		return app
 	}
@@ -1026,6 +1054,8 @@ func (ctx *Context) RemoveApplication(appID string) {
 
 // this implements ApplicationManagementProtocol
 func (ctx *Context) AddTask(request *AddTaskRequest) *Task {
+	ctx.lock.Lock()
+	defer ctx.lock.Unlock()
 	return ctx.addTask(request)
 }
@@ -1074,8 +1104,8 @@ func (ctx *Context) addTask(request *AddTaskRequest) *Task {
 }
 
 func (ctx *Context) RemoveTask(appID, taskID string) {
-	ctx.lock.RLock()
-	defer ctx.lock.RUnlock()
+	ctx.lock.Lock()
+	defer ctx.lock.Unlock()
 	app, ok := ctx.applications[appID]
 	if !ok {
 		log.Log(log.ShimContext).Debug("Attempted to remove task from non-existent application", zap.String("appID", appID))
 		return
 	}
@@ -1085,7 +1115,9 @@ func (ctx *Context) RemoveTask(appID, taskID string) {
 }
 
 func (ctx *Context) getTask(appID string, taskID string) *Task {
-	app := ctx.GetApplication(appID)
+	ctx.lock.RLock()
+	defer ctx.lock.RUnlock()
+	app := ctx.getApplication(appID)
 	if app == nil {
 		log.Log(log.ShimContext).Debug("application is not found in the context",
 			zap.String("appID", appID))
@@ -1354,7 +1386,7 @@ func (ctx *Context) InitializeState() error {
 		log.Log(log.ShimContext).Error("failed to load nodes", zap.Error(err))
 		return err
 	}
-	acceptedNodes, err := ctx.registerNodes(nodes)
+	acceptedNodes, err := ctx.RegisterNodes(nodes)
 	if err != nil {
 		log.Log(log.ShimContext).Error("failed to register nodes", zap.Error(err))
 		return err
 	}
@@ -1474,11 +1506,17 @@ func (ctx *Context) registerNode(node *v1.Node) error {
 	return nil
 }
 
+func (ctx *Context) RegisterNodes(nodes []*v1.Node) ([]*v1.Node, error) {
+	ctx.lock.Lock()
+	defer ctx.lock.Unlock()
+	return ctx.registerNodes(nodes)
+}
+
+// registerNodes registers the nodes to the scheduler core.
+// This method must be called while holding the Context write lock.
 func (ctx *Context) registerNodes(nodes []*v1.Node) ([]*v1.Node, error) {
 	nodesToRegister := make([]*si.NodeInfo, 0)
 	pendingNodes := make(map[string]*v1.Node)
-	acceptedNodes := make([]*v1.Node, 0)
-	rejectedNodes := make([]*v1.Node, 0)
 
 	// Generate a NodeInfo object for each node and add to the registration request
 	for _, node := range nodes {
@@ -1497,12 +1535,34 @@ func (ctx *Context) registerNodes(nodes []*v1.Node) ([]*v1.Node, error) {
 		pendingNodes[node.Name] = node
 	}
 
-	var wg sync.WaitGroup
+	acceptedNodes, rejectedNodes, err := ctx.registerNodesInternal(nodesToRegister, pendingNodes)
+	if err != nil {
+		log.Log(log.ShimContext).Error("Failed to register nodes", zap.Error(err))
+		return nil, err
+	}
+
+	for _, node := range acceptedNodes {
+		// post a successful event to the node
+		events.GetRecorder().Eventf(node.DeepCopy(), nil, v1.EventTypeNormal, "NodeAccepted", "NodeAccepted",
+			fmt.Sprintf("node %s is accepted by the scheduler", node.Name))
+	}
+	for _, node := range rejectedNodes {
+		// post a failure event to the node
+		events.GetRecorder().Eventf(node.DeepCopy(), nil, v1.EventTypeWarning, "NodeRejected", "NodeRejected",
+			fmt.Sprintf("node %s is rejected by the scheduler", node.Name))
+	}
+	return acceptedNodes, nil
+}
+
+func (ctx *Context) registerNodesInternal(nodesToRegister []*si.NodeInfo, pendingNodes map[string]*v1.Node) ([]*v1.Node, []*v1.Node, error) {
+	acceptedNodes := make([]*v1.Node, 0)
+	rejectedNodes := make([]*v1.Node, 0)
+
+	var wg sync.WaitGroup
 	// initialize wait group with the number of responses we expect
 	wg.Add(len(pendingNodes))
 
-	// register with the dispatcher so that we can track our response
 	handlerID := fmt.Sprintf("%s-%d", registerNodeContextHandler, ctx.txnID.Add(1))
 	dispatcher.RegisterEventHandler(handlerID, dispatcher.EventTypeNode, func(event interface{}) {
 		nodeEvent, ok := event.(CachedSchedulerNodeEvent)
@@ -1534,24 +1594,17 @@ func (ctx *Context) registerNodes(nodes []*v1.Node) ([]*v1.Node, error) {
 		RmID: schedulerconf.GetSchedulerConf().ClusterID,
 	}); err != nil {
 		log.Log(log.ShimContext).Error("Failed to register nodes", zap.Error(err))
-		return nil, err
+		return nil, nil, err
 	}
 
+	// write lock must always be held at this point, releasing it while waiting to avoid any potential deadlocks
+	ctx.lock.Unlock()
+	defer ctx.lock.Lock()
+
 	// wait for all responses to accumulate
 	wg.Wait()
 
-	for _, node := range acceptedNodes {
-		// post a successful event to the node
-		events.GetRecorder().Eventf(node.DeepCopy(), nil, v1.EventTypeNormal, "NodeAccepted", "NodeAccepted",
-			fmt.Sprintf("node %s is accepted by the scheduler", node.Name))
-	}
-	for _, node := range rejectedNodes {
-		// post a failure event to the node
-		events.GetRecorder().Eventf(node.DeepCopy(), nil, v1.EventTypeWarning, "NodeRejected", "NodeRejected",
-			fmt.Sprintf("node %s is rejected by the scheduler", node.Name))
-	}
-
-	return acceptedNodes, nil
+	return acceptedNodes, rejectedNodes, nil
 }
 
 func (ctx *Context) decommissionNode(node *v1.Node) error {
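
A note on the pattern used throughout this diff: each exported method (AddApplication, AddTask, RegisterNodes) takes the Context lock and then delegates to an unexported variant (addApplication, addTask, registerNodes) that assumes the lock is already held. Internal callers that already hold the lock, such as ensureAppAndTaskCreated and deleteYuniKornPod above, now call the unexported variants directly. Below is a minimal sketch of that idea, using a simplified hypothetical Context rather than the real YuniKorn types:

```go
package main

import (
	"fmt"
	"sync"
)

// Context is a simplified stand-in for the shim's cache.Context.
type Context struct {
	lock         sync.RWMutex
	applications map[string]bool
}

// AddApplication is the exported entry point: it takes the write lock,
// then delegates to the unexported variant.
func (ctx *Context) AddApplication(appID string) bool {
	ctx.lock.Lock()
	defer ctx.lock.Unlock()
	return ctx.addApplication(appID)
}

// addApplication assumes the write lock is already held by the caller.
func (ctx *Context) addApplication(appID string) bool {
	if ctx.applications[appID] {
		return false // already present
	}
	ctx.applications[appID] = true
	return true
}

// handleEvent models an event handler that holds the lock for the whole
// callback; it must use addApplication, because sync.RWMutex is not
// reentrant and a nested AddApplication call would self-deadlock.
func (ctx *Context) handleEvent(appID string) {
	ctx.lock.Lock()
	defer ctx.lock.Unlock()
	ctx.addApplication(appID)
}

func main() {
	ctx := &Context{applications: make(map[string]bool)}
	ctx.handleEvent("app-1")
	fmt.Println(ctx.AddApplication("app-1")) // false: app-1 already exists
}
```

The exported/unexported split makes lock ownership explicit at every call site, which matters once the same lock also serializes whole event handlers, as the updated comment on the lock field describes.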
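The RemoveTask change from RLock/RUnlock to Lock/Unlock fixes a different bug class: the method mutates application state, and a read lock admits many goroutines at once, so a write performed under RLock can race. A small sketch of that hazard, with a hypothetical taskStore standing in for the real types:

```go
package main

import (
	"fmt"
	"sync"
)

// taskStore is a hypothetical stand-in; the point is the lock usage.
type taskStore struct {
	lock  sync.RWMutex
	tasks map[string]struct{}
}

// removeTaskBuggy shows the bug class: RLock allows concurrent entry, so
// the map write below can race ('go run -race' reports it when two
// goroutines call this at the same time).
func (s *taskStore) removeTaskBuggy(id string) {
	s.lock.RLock()
	defer s.lock.RUnlock()
	delete(s.tasks, id) // write under a read lock
}

// removeTask is the fixed form: writers take the write lock.
func (s *taskStore) removeTask(id string) {
	s.lock.Lock()
	defer s.lock.Unlock()
	delete(s.tasks, id)
}

func main() {
	s := &taskStore{tasks: map[string]struct{}{"a": {}, "b": {}}}
	var wg sync.WaitGroup
	for _, id := range []string{"a", "b"} {
		wg.Add(1)
		go func(id string) {
			defer wg.Done()
			s.removeTask(id) // swap in removeTaskBuggy to see the race
		}(id)
	}
	wg.Wait()
	fmt.Println(len(s.tasks)) // 0
}
```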
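registerNodesInternal is the subtle part of the diff: it is entered with the write lock held, but the accepted/rejected responses are delivered by a dispatcher callback, so blocking in wg.Wait() while still holding the lock could deadlock. The code therefore drops the lock for the wait and re-takes it with defer ctx.lock.Lock(), so the caller still observes the lock held on return. A simplified sketch of that shape, with a channel standing in for the dispatcher and hypothetical names throughout:

```go
package main

import (
	"fmt"
	"sync"
)

// Registrar is a hypothetical stand-in for the Context; responses arrive
// asynchronously, modeled here by a channel instead of the dispatcher.
type Registrar struct {
	lock sync.Mutex
}

// registerInternal is entered with r.lock held, like registerNodesInternal.
func (r *Registrar) registerInternal(pending int, responses <-chan string) []string {
	var wg sync.WaitGroup
	wg.Add(pending)

	accepted := make([]string, 0, pending)
	go func() {
		// stand-in for the dispatcher callback collecting responses
		for node := range responses {
			accepted = append(accepted, node)
			wg.Done()
		}
	}()

	// The lock is held at this point; release it while blocked so the
	// response path can make progress, and re-acquire it on return so the
	// caller sees the lock exactly as it left it.
	r.lock.Unlock()
	defer r.lock.Lock()

	wg.Wait() // safe: the lock is no longer held while waiting
	return accepted
}

func main() {
	r := &Registrar{}
	responses := make(chan string, 2)
	responses <- "node-1"
	responses <- "node-2"
	close(responses)

	r.lock.Lock()
	accepted := r.registerInternal(2, responses)
	r.lock.Unlock()
	fmt.Println(accepted) // [node-1 node-2]
}
```

The unusual defer ctx.lock.Lock() restores the function's entry invariant (lock held) on every return path, which keeps the deferred Unlock in RegisterNodes balanced.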