
Don't leave pods running when a job completes. (#512)
* Don't leave pods running when a job completes.

* We originally did this to preserve the logs.
* But this ends up leaving pods running and consuming resources after the job finishes.
* The fix is straightforward:
  * Transition to the cleanup phase before transitioning to the done phase.

Fix #128

* Don't tear down the cluster.

* Don't set the phase to cleanup while the job is running.

* Only call GetStatus if we are in the creating or running phase.

* Update the E2E test

* Check that pod/service creation events are recorded.
* Check that pods are deleted when the job ends.

* Fix lint.
jlewi authored and k8s-ci-robot committed Mar 29, 2018
1 parent 41a20d4 commit a7511ff
Showing 4 changed files with 173 additions and 73 deletions.
pkg/trainer/replicas.go (2 changes: 1 addition & 1 deletion)
@@ -442,7 +442,7 @@ func (s *TFReplicaSet) SyncPods() error {
}

if len(pl.Items) == 0 {
s.contextLogger.Infof("Job %s missing pod for replica %s index %s, creating a new one.", s.Job.name(), string(s.Spec.TFReplicaType), index)
s.contextLogger.Infof("Job %v missing pod for replica %v index %v, creating a new one.", s.Job.name(), string(s.Spec.TFReplicaType), index)
// Create the pod
createdPod, err := s.CreatePodWithIndex(index)

pkg/trainer/training.go (70 changes: 35 additions & 35 deletions)
@@ -389,52 +389,52 @@ func (j *TrainingJob) Reconcile(config *tfv1alpha1.ControllerConfig, enableGangS
j.contextLogger.Errorf("SyncServices error: %v", err)
}
}
}

if err := j.updateCRDStatus(); err != nil {
j.contextLogger.Warningf("Job %v; failed to update status error: %v", j.job.ObjectMeta.Name, err)
return err
}
if err := j.updateCRDStatus(); err != nil {
j.contextLogger.Warningf("Job %v; failed to update status error: %v", j.job.ObjectMeta.Name, err)
return err
}

// Call GetStatus in each reconcile loop
state, replicaStatuses, err := j.GetStatus()
// Call GetStatus in each reconcile loop
state, replicaStatuses, err := j.GetStatus()

j.status.ReplicaStatuses = replicaStatuses
if err != nil {
j.contextLogger.Errorf("GetStatus() for job %v returned error: %v", j.job.ObjectMeta.Name, err)
return err
}
j.status.ReplicaStatuses = replicaStatuses
if err != nil {
j.contextLogger.Errorf("GetStatus() for job %v returned error: %v", j.job.ObjectMeta.Name, err)
return err
}

// TODO(jlewi): We should update the Phase if we detect the job is done.
if state == tfv1alpha1.StateFailed {
j.contextLogger.Errorf("Master failed Job: %v.", j.job.ObjectMeta.Name)
j.status.Phase = tfv1alpha1.TFJobPhaseDone
j.status.State = tfv1alpha1.StateFailed
} else if state == tfv1alpha1.StateSucceeded {
j.contextLogger.Infof("Master succeeded Job: %v.", j.job.ObjectMeta.Name)
j.status.Phase = tfv1alpha1.TFJobPhaseDone
j.status.State = tfv1alpha1.StateSucceeded
} else if state == tfv1alpha1.StateRunning {
j.contextLogger.Infof("Master running Job: %v.", j.job.ObjectMeta.Name)
j.status.Phase = tfv1alpha1.TFJobPhaseRunning
j.status.State = tfv1alpha1.StateRunning
} else {
j.contextLogger.Infof("Job %v status=%v", j.job.ObjectMeta.Name, util.Pformat(j.status))
}
// TODO(jlewi): We should update the Phase if we detect the job is done.
if state == tfv1alpha1.StateFailed {
j.contextLogger.Errorf("Master failed Job: %v.", j.job.ObjectMeta.Name)
j.status.Phase = tfv1alpha1.TFJobPhaseCleanUp
j.status.State = tfv1alpha1.StateFailed
} else if state == tfv1alpha1.StateSucceeded {
j.contextLogger.Infof("Master succeeded Job: %v.", j.job.ObjectMeta.Name)
j.status.Phase = tfv1alpha1.TFJobPhaseCleanUp
j.status.State = tfv1alpha1.StateSucceeded
} else if state == tfv1alpha1.StateRunning {
j.contextLogger.Infof("Master running Job: %v.", j.job.ObjectMeta.Name)
j.status.Phase = tfv1alpha1.TFJobPhaseRunning
j.status.State = tfv1alpha1.StateRunning
} else {
j.contextLogger.Infof("Job %v status=%v", j.job.ObjectMeta.Name, util.Pformat(j.status))
}

// If the phase changed we should update the CRD.
if err := j.updateCRDStatus(); err != nil {
j.contextLogger.Warningf("Job %v, failed to update CRD status error: %v", j.job.ObjectMeta.Name, err)
return err
// If the phase changed we should update the CRD.
if err := j.updateCRDStatus(); err != nil {
j.contextLogger.Warningf("Job %v, failed to update CRD status error: %v", j.job.ObjectMeta.Name, err)
return err
}
}

if j.job.Status.Phase == tfv1alpha1.TFJobPhaseCleanUp {
if cErr := j.deleteResources(); cErr != nil {
j.contextLogger.Errorf("Job %v trainingJob.Delete() error; %v", j.job.ObjectMeta.Name, cErr)
// Return an error so that we stay in phase cleanup and retry.
return cErr
}
// j.status.SetPhase(spec.TFJobPhaseDone)
// Return from run because we want to stop reconciling the object.
return nil
j.status.Phase = tfv1alpha1.TFJobPhaseDone
}

// updateCRDStatus will update the status of the CRD with c.Status if c.Status
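
Most of the interleaved old/new lines above are an indentation change (per the commit message, GetStatus is now only consulted while the job is creating or running); the substantive change is that a failed or succeeded job now enters the CleanUp phase, whose handler deletes the job's resources before the phase moves to Done. The following is a minimal, self-contained sketch of that flow, written in Python for brevity; the phase/state names and the delete_resources callback are simplified stand-ins, not the upstream tfv1alpha1 types.

# Simplified stand-ins for the tfv1alpha1 phases/states; not the upstream types.
RUNNING, CLEANUP, DONE = "Running", "CleanUp", "Done"
SUCCEEDED, FAILED = "Succeeded", "Failed"


def reconcile(job, observed_state, delete_resources):
  """One reconcile pass over a simplified job dict."""
  if job["phase"] == RUNNING:
    if observed_state in (SUCCEEDED, FAILED):
      job["state"] = observed_state
      job["phase"] = CLEANUP  # before this commit the phase jumped straight to Done
    elif observed_state == RUNNING:
      job["state"] = RUNNING

  if job["phase"] == CLEANUP:
    if not delete_resources():  # returns False when deletion fails
      # Stay in CleanUp so the next reconcile retries the deletion.
      return False
    job["phase"] = DONE
  return True


job = {"phase": RUNNING, "state": RUNNING}
reconcile(job, SUCCEEDED, lambda: True)
print(job)  # {'phase': 'Done', 'state': 'Succeeded'}

Staying in CleanUp when deletion fails is exactly what the comment in the diff calls out ("Return an error so that we stay in phase cleanup and retry.").
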
py/test_runner.py (165 changes: 132 additions & 33 deletions)
@@ -6,6 +6,7 @@
import logging
import json
import os
import re
import time
import uuid

@@ -58,6 +59,39 @@ def wait_for_delete(client,
time.sleep(polling_interval.seconds)


def wait_for_pods_to_be_deleted(client,
namespace,
pod_selector,
timeout=datetime.timedelta(minutes=5),
polling_interval=datetime.timedelta(
seconds=30)):
"""Wait for the specified job to be deleted.
Args:
client: K8s api client.
namespace: Namespace.
pod_selector: Selector for the pods.
timeout: How long to wait for the job.
polling_interval: How often to poll for the status of the job.
status_callback: (Optional): Callable. If supplied this callable is
invoked after we poll the job. Callable takes a single argument which
is the job.
"""
end_time = datetime.datetime.now() + timeout
while True:
pods = list_pods(client, namespace, pod_selector)

logging.info("%s pods matched %s pods", len(pods.items), pod_selector)

if not pods.items:
return

if datetime.datetime.now() + polling_interval > end_time:
raise util.TimeoutError("Timeout waiting for pods to be deleted.")

time.sleep(polling_interval.seconds)
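
# Illustration only (not part of this commit): with the helpers defined
# elsewhere in this file, a caller such as run_test might use this as
#
#   labels = get_labels(name, runtime_id)
#   wait_for_pods_to_be_deleted(api_client, namespace, to_selector(labels))
#
# which returns once no pods match the selector and raises
# util.TimeoutError if pods are still around after the timeout.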


def get_labels(name, runtime_id, replica_type=None, replica_index=None):
"""Return labels.
"""
@@ -108,6 +142,73 @@ def list_pods(client, namespace, label_selector):
raise e


def get_events(client, namespace, uid):
"""Get the events for the provided object."""
core = k8s_client.CoreV1Api(client)
try:
# We can't filter by labels because events don't appear to have any,
# and I didn't see an easy way to filter them otherwise.
events = core.list_namespaced_event(namespace)
except rest.ApiException as e:
message = ""
if e.message:
message = e.message
if e.body:
try:
body = json.loads(e.body)
except ValueError:
# There was a problem parsing the body of the response as json.
logging.error(
("Exception when calling DefaultApi->"
"apis_fqdn_v1_namespaces_namespace_resource_post. body: %s"), e.body)
raise
message = body.get("message")

logging.error(("Exception when calling DefaultApi->"
"apis_fqdn_v1_namespaces_namespace_resource_post: %s"),
message)
raise e

matching = []

for e in events.items:
if e.involved_object.uid != uid:
continue
matching.append(e)

return matching


def parse_events(events):
"""Parse events.
Args:
events: List of events.
Returns:
pods_created: Set of unique pod names created.
services_created: Set of unique services created.
"""
pattern = re.compile("Created.*(pod|Service).*: (.*)", re.IGNORECASE)

pods = set()
services = set()
for e in events:
m = re.match(pattern, e.message)
if not m:
continue

kind = m.group(1)
name = m.group(2)

if kind.lower() == "pod":
pods.add(name)
elif kind.lower() == "service":
services.add(name)

return pods, services


def run_test(args): # pylint: disable=too-many-branches,too-many-statements
"""Run a test."""
gcs_client = storage.Client(project=args.project)
@@ -178,53 +279,52 @@ def run_test(args): # pylint: disable=too-many-branches,too-many-statements

if results.get("status", {}).get("state", {}).lower() != "succeeded":
t.failure = "Trial {0} Job {1} in namespace {2} in state {3}".format(
trial, name, namespace,
results.get("status", {}).get("state", None))
trial, name, namespace, results.get("status", {}).get("state", None))
logging.error(t.failure)
break

runtime_id = results.get("spec", {}).get("RuntimeId")
logging.info("Trial %s Job %s in namespace %s runtime ID %s", trial, name,
namespace, runtime_id)

# TODO(jlewi): We should check that pods were created for each replica
uid = results.get("metadata", {}).get("uid")
events = get_events(api_client, namespace, uid)
created_pods, created_services = parse_events(events)

num_expected = 0
for replica in results.get("spec", {}).get("replicaSpecs", []):
num_expected += replica.get("replicas", 0)

creation_failures = []
if len(created_pods) != num_expected:
message = ("Expected {0} pods to be created but only "
"got {1} create events.").format(num_expected,
len(created_pods))
creation_failures.append(message)

if len(created_services) != num_expected:
message = ("Expected {0} services to be created but only "
"got {1} create events.").format(num_expected,
len(created_services))
creation_failures.append(message)

if creation_failures:
t.failure = "Trial {0} Job {1} in namespace {2}: {3}".format(
trial, name, namespace, ", ".join(creation_failures))
logging.error(t.failure)
break
pod_labels = get_labels(name, runtime_id)
pod_selector = to_selector(pod_labels)
pods = list_pods(api_client, namespace, pod_selector)

logging.info("Trial %s selector: %s matched %s pods", trial, pod_selector,
len(pods.items))

if not pods.items:
t.failure = ("Trial {0} Job {1} in namespace {2} no pods found for "
" selector {3}").format(trial, name, namespace,
pod_selector)
logging.error(t.failure)
break
wait_for_pods_to_be_deleted(api_client, namespace, pod_selector)

tf_job_client.delete_tf_job(api_client, namespace, name)

logging.info("Waiting for job %s in namespaces %s to be deleted.", name, namespace)
logging.info("Waiting for job %s in namespaces %s to be deleted.", name,
namespace)
wait_for_delete(
api_client, namespace, name, status_callback=tf_job_client.log_status)

# Verify the pods have been deleted. tf_job_client uses foreground
# deletion so there shouldn't be any resources for the job left
# once the job is gone.
pods = list_pods(api_client, namespace, pod_selector)

logging.info("Trial %s selector: %s matched %s pods", trial, pod_selector,
len(pods.items))

if pods.items:
t.failure = ("Trial {0} Job {1} in namespace {2} pods found for "
" selector {3}; pods\n{4}").format(trial, name, namespace,
pod_selector, pods)
logging.error(t.failure)
break

logging.info("Trial %s all pods deleted.", trial)

# TODO(jlewi):
# Here are some validation checks to run:
# 1. Check that all resources are garbage collected.
@@ -312,8 +412,7 @@ def main(): # pylint: disable=too-many-locals
level=logging.INFO,
format=('%(levelname)s|%(asctime)s'
'|%(pathname)s|%(lineno)d| %(message)s'),
datefmt='%Y-%m-%dT%H:%M:%S',
)
datefmt='%Y-%m-%dT%H:%M:%S',)

util.maybe_activate_service_account()

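
To make the new event check easier to follow, here is a small, self-contained sketch (not part of the commit) of what parse_events pulls out of pod/service creation events. The event messages below are made up, but they have the shape the regex expects.

import re

# Made-up event messages in the shape the parse_events regex matches.
messages = [
    "Created pod: myjob-master-abcd-0-xyz",
    "Created Service: myjob-ps-abcd-0",
    "Started container",  # no match: does not start with "Created"
]

pattern = re.compile("Created.*(pod|Service).*: (.*)", re.IGNORECASE)

pods, services = set(), set()
for msg in messages:
  m = pattern.match(msg)
  if not m:
    continue
  kind, name = m.group(1), m.group(2)
  if kind.lower() == "pod":
    pods.add(name)
  elif kind.lower() == "service":
    services.add(name)

print(sorted(pods))      # ['myjob-master-abcd-0-xyz']
print(sorted(services))  # ['myjob-ps-abcd-0']

run_test then compares the two sets against the total replica count computed from the TFJob spec's replicaSpecs.
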
test/workflows/components/workflows.libsonnet (9 changes: 5 additions & 4 deletions)
@@ -218,10 +218,11 @@
{
name: "exit-handler",
steps: [
[{
name: "teardown-cluster",
template: "teardown-cluster",
}],
// DO NOT SUBMIT comment out to facilitate debugging.
//[{
// name: "teardown-cluster",
// template: "teardown-cluster",
//}],
[{
name: "copy-artifacts",
template: "copy-artifacts",
