[YUNIKORN-2910] Fix data corruption due to insufficient shim context locking (#924)

Restore context locking that was removed as part of YUNIKORN-2629. The
locks are necessary to prevent logical data corruption due to concurrent
processing of both pod and node events.

Closes: #924
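
Below is a minimal, illustrative sketch (not the shim's actual code) of the locking pattern this commit restores: a single RWMutex guards the cached state and also serializes the pod and node event handlers so their updates cannot interleave. It uses sync.RWMutex in place of YuniKorn's internal locking package, and all type, field, and function names here are hypothetical.

```go
package main

import (
	"fmt"
	"sync"
)

// shimContext mirrors the role of the shim's Context: one lock protects the
// cached state and also keeps different event types from running concurrently
// (hypothetical stand-in, not the real type).
type shimContext struct {
	lock  sync.RWMutex
	nodes map[string]string // nodeName -> status (stand-in for the scheduler cache)
	pods  map[string]string // podName -> assigned nodeName
}

// updateNode is a write-path event handler: it takes the write lock, as the
// patch restores for the shim's updateNode/deleteNode handlers.
func (c *shimContext) updateNode(name, status string) {
	c.lock.Lock()
	defer c.lock.Unlock()
	c.nodes[name] = status
}

// updatePod is another write-path handler; sharing the same lock guarantees it
// never observes a half-applied node update.
func (c *shimContext) updatePod(pod, node string) {
	c.lock.Lock()
	defer c.lock.Unlock()
	c.pods[pod] = node
}

// isPodFitNode is read-only and only needs the read lock.
func (c *shimContext) isPodFitNode(pod, node string) bool {
	c.lock.RLock()
	defer c.lock.RUnlock()
	_, podKnown := c.pods[pod]
	_, nodeKnown := c.nodes[node]
	return podKnown && nodeKnown
}

func main() {
	c := &shimContext{nodes: map[string]string{}, pods: map[string]string{}}
	var wg sync.WaitGroup
	wg.Add(2)
	go func() { defer wg.Done(); c.updateNode("node-1", "Ready") }()
	go func() { defer wg.Done(); c.updatePod("pod-1", "node-1") }()
	wg.Wait()
	fmt.Println(c.isPodFitNode("pod-1", "node-1")) // true once both handlers have run
}
```

This split also explains the shape of the diff below: mutating handlers gain full Lock/Unlock, while read-only predicate paths such as IsPodFitNode and IsPodFitNodeViaPreemption only gain RLock/RUnlock.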
craigcondit committed Oct 10, 2024
1 parent 92a17ad commit e9b05eb
Showing 1 changed file with 77 additions and 24 deletions.
101 changes: 77 additions & 24 deletions pkg/cache/context.go
@@ -72,7 +72,7 @@ type Context struct {
pluginMode bool // true if we are configured as a scheduler plugin
namespace string // yunikorn namespace
configMaps []*v1.ConfigMap // cached yunikorn configmaps
lock *locking.RWMutex // lock
lock *locking.RWMutex // lock - used not only for context data but also to ensure that multiple event types are not executed concurrently
txnID atomic.Uint64 // transaction ID counter
klogger klog.Logger
}
@@ -166,6 +166,8 @@ func (ctx *Context) addNode(obj interface{}) {
}

func (ctx *Context) updateNode(_, obj interface{}) {
ctx.lock.Lock()
defer ctx.lock.Unlock()
node, err := convertToNode(obj)
if err != nil {
log.Log(log.ShimContext).Error("node conversion failed", zap.Error(err))
@@ -227,6 +229,8 @@ func (ctx *Context) updateNodeInternal(node *v1.Node, register bool) {
}

func (ctx *Context) deleteNode(obj interface{}) {
ctx.lock.Lock()
defer ctx.lock.Unlock()
var node *v1.Node
switch t := obj.(type) {
case *v1.Node:
@@ -246,6 +250,8 @@ func (ctx *Context) deleteNode(obj interface{}) {
}

func (ctx *Context) addNodesWithoutRegistering(nodes []*v1.Node) {
ctx.lock.Lock()
defer ctx.lock.Unlock()
for _, node := range nodes {
ctx.updateNodeInternal(node, false)
}
@@ -281,6 +287,8 @@ func (ctx *Context) AddPod(obj interface{}) {
}

func (ctx *Context) UpdatePod(_, newObj interface{}) {
ctx.lock.Lock()
defer ctx.lock.Unlock()
pod, err := utils.Convert2Pod(newObj)
if err != nil {
log.Log(log.ShimContext).Error("failed to update pod", zap.Error(err))
@@ -328,7 +336,7 @@ func (ctx *Context) ensureAppAndTaskCreated(pod *v1.Pod, app *Application) {
zap.String("name", pod.Name))
return
}
app = ctx.AddApplication(&AddApplicationRequest{
app = ctx.addApplication(&AddApplicationRequest{
Metadata: appMeta,
})
}
@@ -432,15 +440,19 @@ func (ctx *Context) DeletePod(obj interface{}) {
}

func (ctx *Context) deleteYuniKornPod(pod *v1.Pod) {
ctx.lock.Lock()
defer ctx.lock.Unlock()
if taskMeta, ok := getTaskMetadata(pod); ok {
ctx.notifyTaskComplete(ctx.GetApplication(taskMeta.ApplicationID), taskMeta.TaskID)
ctx.notifyTaskComplete(ctx.getApplication(taskMeta.ApplicationID), taskMeta.TaskID)
}

log.Log(log.ShimContext).Debug("removing pod from cache", zap.String("podName", pod.Name))
ctx.schedulerCache.RemovePod(pod)
}

func (ctx *Context) deleteForeignPod(pod *v1.Pod) {
ctx.lock.Lock()
defer ctx.lock.Unlock()
oldPod := ctx.schedulerCache.GetPod(string(pod.UID))
if oldPod == nil {
// if pod is not in scheduler cache, no node updates are needed
@@ -571,6 +583,8 @@ func (ctx *Context) addPriorityClass(obj interface{}) {
}

func (ctx *Context) updatePriorityClass(_, newObj interface{}) {
ctx.lock.Lock()
defer ctx.lock.Unlock()
if priorityClass := utils.Convert2PriorityClass(newObj); priorityClass != nil {
ctx.updatePriorityClassInternal(priorityClass)
}
@@ -581,6 +595,8 @@ func (ctx *Context) updatePriorityClassInternal(priorityClass *schedulingv1.Prio
}

func (ctx *Context) deletePriorityClass(obj interface{}) {
ctx.lock.Lock()
defer ctx.lock.Unlock()
log.Log(log.ShimContext).Debug("priorityClass deleted")
var priorityClass *schedulingv1.PriorityClass
switch t := obj.(type) {
@@ -646,6 +662,8 @@ func (ctx *Context) EventsToRegister(queueingHintFn framework.QueueingHintFn) []

// IsPodFitNode evaluates given predicates based on current context
func (ctx *Context) IsPodFitNode(name, node string, allocate bool) error {
ctx.lock.RLock()
defer ctx.lock.RUnlock()
pod := ctx.schedulerCache.GetPod(name)
if pod == nil {
return ErrorPodNotFound
@@ -666,6 +684,8 @@ func (ctx *Context) IsPodFitNode(name, node string, allocate bool) error {
}

func (ctx *Context) IsPodFitNodeViaPreemption(name, node string, allocations []string, startIndex int) (int, bool) {
ctx.lock.RLock()
defer ctx.lock.RUnlock()
if pod := ctx.schedulerCache.GetPod(name); pod != nil {
// if pod exists in cache, try to run predicates
if targetNode := ctx.schedulerCache.GetNode(node); targetNode != nil {
@@ -774,6 +794,8 @@ func (ctx *Context) bindPodVolumes(pod *v1.Pod) error {
// this way, the core can make allocation decisions with consideration of
// other assumed pods before they are actually bound to the node (bound is slow).
func (ctx *Context) AssumePod(name, node string) error {
ctx.lock.Lock()
defer ctx.lock.Unlock()
if pod := ctx.schedulerCache.GetPod(name); pod != nil {
// when add assumed pod, we make a copy of the pod to avoid
// modifying its original reference. otherwise, it may have
@@ -833,6 +855,8 @@ func (ctx *Context) AssumePod(name, node string) error {
// forget pod must be called when a pod is assumed to be running on a node,
// but then for some reason it is failed to bind or released.
func (ctx *Context) ForgetPod(name string) {
ctx.lock.Lock()
defer ctx.lock.Unlock()
if pod := ctx.schedulerCache.GetPod(name); pod != nil {
log.Log(log.ShimContext).Debug("forget pod", zap.String("pod", pod.Name))
ctx.schedulerCache.ForgetPod(pod)
@@ -949,6 +973,10 @@ func (ctx *Context) AddApplication(request *AddApplicationRequest) *Application
ctx.lock.Lock()
defer ctx.lock.Unlock()

return ctx.addApplication(request)
}

func (ctx *Context) addApplication(request *AddApplicationRequest) *Application {
log.Log(log.ShimContext).Debug("AddApplication", zap.Any("Request", request))
if app := ctx.getApplication(request.Metadata.ApplicationID); app != nil {
return app
@@ -1026,6 +1054,8 @@ func (ctx *Context) RemoveApplication(appID string) {

// this implements ApplicationManagementProtocol
func (ctx *Context) AddTask(request *AddTaskRequest) *Task {
ctx.lock.Lock()
defer ctx.lock.Unlock()
return ctx.addTask(request)
}

@@ -1074,8 +1104,8 @@ func (ctx *Context) addTask(request *AddTaskRequest) *Task {
}

func (ctx *Context) RemoveTask(appID, taskID string) {
ctx.lock.RLock()
defer ctx.lock.RUnlock()
ctx.lock.Lock()
defer ctx.lock.Unlock()
app, ok := ctx.applications[appID]
if !ok {
log.Log(log.ShimContext).Debug("Attempted to remove task from non-existent application", zap.String("appID", appID))
@@ -1085,7 +1115,9 @@ func (ctx *Context) RemoveTask(appID, taskID string) {
}

func (ctx *Context) getTask(appID string, taskID string) *Task {
app := ctx.GetApplication(appID)
ctx.lock.RLock()
defer ctx.lock.RUnlock()
app := ctx.getApplication(appID)
if app == nil {
log.Log(log.ShimContext).Debug("application is not found in the context",
zap.String("appID", appID))
@@ -1354,7 +1386,7 @@ func (ctx *Context) InitializeState() error {
log.Log(log.ShimContext).Error("failed to load nodes", zap.Error(err))
return err
}
acceptedNodes, err := ctx.registerNodes(nodes)
acceptedNodes, err := ctx.RegisterNodes(nodes)
if err != nil {
log.Log(log.ShimContext).Error("failed to register nodes", zap.Error(err))
return err
@@ -1474,11 +1506,17 @@ func (ctx *Context) registerNode(node *v1.Node) error {
return nil
}

func (ctx *Context) RegisterNodes(nodes []*v1.Node) ([]*v1.Node, error) {
ctx.lock.Lock()
defer ctx.lock.Unlock()
return ctx.registerNodes(nodes)
}

// registerNodes registers the nodes to the scheduler core.
// This method must be called while holding the Context write lock.
func (ctx *Context) registerNodes(nodes []*v1.Node) ([]*v1.Node, error) {
nodesToRegister := make([]*si.NodeInfo, 0)
pendingNodes := make(map[string]*v1.Node)
acceptedNodes := make([]*v1.Node, 0)
rejectedNodes := make([]*v1.Node, 0)

// Generate a NodeInfo object for each node and add to the registration request
for _, node := range nodes {
@@ -1497,12 +1535,34 @@ func (ctx *Context) registerNodes(nodes []*v1.Node) ([]*v1.Node, error) {
pendingNodes[node.Name] = node
}

var wg sync.WaitGroup
acceptedNodes, rejectedNodes, err := ctx.registerNodesInternal(nodesToRegister, pendingNodes)
if err != nil {
log.Log(log.ShimContext).Error("Failed to register nodes", zap.Error(err))
return nil, err
}

for _, node := range acceptedNodes {
// post a successful event to the node
events.GetRecorder().Eventf(node.DeepCopy(), nil, v1.EventTypeNormal, "NodeAccepted", "NodeAccepted",
fmt.Sprintf("node %s is accepted by the scheduler", node.Name))
}
for _, node := range rejectedNodes {
// post a failure event to the node
events.GetRecorder().Eventf(node.DeepCopy(), nil, v1.EventTypeWarning, "NodeRejected", "NodeRejected",
fmt.Sprintf("node %s is rejected by the scheduler", node.Name))
}

return acceptedNodes, nil
}

func (ctx *Context) registerNodesInternal(nodesToRegister []*si.NodeInfo, pendingNodes map[string]*v1.Node) ([]*v1.Node, []*v1.Node, error) {
acceptedNodes := make([]*v1.Node, 0)
rejectedNodes := make([]*v1.Node, 0)

var wg sync.WaitGroup
// initialize wait group with the number of responses we expect
wg.Add(len(pendingNodes))

// register with the dispatcher so that we can track our response
handlerID := fmt.Sprintf("%s-%d", registerNodeContextHandler, ctx.txnID.Add(1))
dispatcher.RegisterEventHandler(handlerID, dispatcher.EventTypeNode, func(event interface{}) {
nodeEvent, ok := event.(CachedSchedulerNodeEvent)
@@ -1534,24 +1594,17 @@ func (ctx *Context) registerNodes(nodes []*v1.Node) ([]*v1.Node, error) {
RmID: schedulerconf.GetSchedulerConf().ClusterID,
}); err != nil {
log.Log(log.ShimContext).Error("Failed to register nodes", zap.Error(err))
return nil, err
return nil, nil, err
}

// write lock must always be held at this point, releasing it while waiting to avoid any potential deadlocks
ctx.lock.Unlock()
defer ctx.lock.Lock()

// wait for all responses to accumulate
wg.Wait()

for _, node := range acceptedNodes {
// post a successful event to the node
events.GetRecorder().Eventf(node.DeepCopy(), nil, v1.EventTypeNormal, "NodeAccepted", "NodeAccepted",
fmt.Sprintf("node %s is accepted by the scheduler", node.Name))
}
for _, node := range rejectedNodes {
// post a failure event to the node
events.GetRecorder().Eventf(node.DeepCopy(), nil, v1.EventTypeWarning, "NodeRejected", "NodeRejected",
fmt.Sprintf("node %s is rejected by the scheduler", node.Name))
}

return acceptedNodes, nil
return acceptedNodes, rejectedNodes, nil
}

func (ctx *Context) decommissionNode(node *v1.Node) error {
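Several hunks above also split exported methods into a lock-taking wrapper plus an unexported helper (AddApplication/addApplication, RegisterNodes/registerNodes) and switch internal callers to the lock-free variants (getApplication, addApplication). The sketch below shows that convention with hypothetical names and sync.RWMutex standing in for the shim's locking package: the exported method owns the lock, while the lowercase helper assumes the caller already holds it, avoiding self-deadlock on a mutex that the same goroutine cannot re-acquire.

```go
package main

import (
	"fmt"
	"sync"
)

// registry is a hypothetical stand-in for the shim Context's application map.
type registry struct {
	lock sync.RWMutex
	apps map[string]string // appID -> queue
}

// AddApplication is the public entry point; it owns the locking.
func (r *registry) AddApplication(id, queue string) string {
	r.lock.Lock()
	defer r.lock.Unlock()
	return r.addApplication(id, queue)
}

// addApplication must only be called with the write lock already held,
// mirroring the "must be called while holding the Context write lock" comment
// in the patch. Handlers that already hold the lock (for example an
// UpdatePod-style callback) call this variant directly.
func (r *registry) addApplication(id, queue string) string {
	if existing, ok := r.apps[id]; ok {
		return existing
	}
	r.apps[id] = queue
	return queue
}

func main() {
	r := &registry{apps: map[string]string{}}
	fmt.Println(r.AddApplication("app-1", "root.default"))
}
```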

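Finally, registerNodesInternal above releases the write lock while waiting for the node responses to accumulate (ctx.lock.Unlock() followed by defer ctx.lock.Lock() before wg.Wait()), so event handlers are not blocked for the whole round trip and the lock is held again before returning. A simplified, hypothetical sketch of that unlock-while-waiting shape (not the dispatcher-based implementation itself):

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// nodeTracker is a hypothetical stand-in for the shim Context during node registration.
type nodeTracker struct {
	lock     sync.RWMutex
	accepted []string
}

// waitForResponses must be entered with the write lock held. It drops the lock
// while waiting for asynchronous responses, then re-acquires it (via defer)
// before returning, so the caller still ends up holding the lock.
func (t *nodeTracker) waitForResponses(nodes []string) []string {
	var wg sync.WaitGroup
	wg.Add(len(nodes))

	results := make(chan string, len(nodes))
	for _, n := range nodes {
		go func(name string) { // stand-in for the dispatcher callback
			defer wg.Done()
			time.Sleep(10 * time.Millisecond)
			results <- name
		}(n)
	}

	// The write lock is held at this point; release it while waiting so
	// concurrent pod/node handlers are not blocked, and take it back on return.
	t.lock.Unlock()
	defer t.lock.Lock()

	wg.Wait()
	close(results)
	accepted := make([]string, 0, len(nodes))
	for name := range results {
		accepted = append(accepted, name)
	}
	return accepted
}

func main() {
	t := &nodeTracker{}
	t.lock.Lock()
	t.accepted = t.waitForResponses([]string{"node-1", "node-2"}) // lock is re-held here
	t.lock.Unlock()
	fmt.Println(t.accepted)
}
```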