Skip to content

Commit

Permalink
[SDK] test: add unit test for get_job method of the training_client (#…
Browse files Browse the repository at this point in the history
…2205)

Signed-off-by: Bobbins228 <[email protected]>
  • Loading branch information
Bobbins228 authored Sep 4, 2024
1 parent c64a5a6 commit 00eef58
Showing 1 changed file with 147 additions and 56 deletions.
203 changes: 147 additions & 56 deletions sdk/python/kubeflow/training/api/training_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from kubeflow.training.models import V1DeleteOptions
from kubernetes.client import (
ApiClient,
V1Container,
V1ObjectMeta,
V1PodSpec,
Expand All @@ -42,6 +43,27 @@ def conditional_error_handler(*args, **kwargs):
raise RuntimeError()


def serialize_k8s_object(obj):
api_client = ApiClient()
return api_client.sanitize_for_serialization(obj)


def get_namespaced_custom_object_response(*args, **kwargs):
if args[2] == "timeout":
raise multiprocessing.TimeoutError()
elif args[2] == "runtime":
raise RuntimeError()

# Create a serialized Job
serialized_job = serialize_k8s_object(generate_job_with_status(create_job()))

# Mock the thread and set it's return value to the serialized Job
mock_thread = Mock()
mock_thread.get.return_value = serialized_job

return mock_thread


def list_namespaced_pod_response(*args, **kwargs):
class MockResponse:
def get(self, timeout):
Expand Down Expand Up @@ -419,6 +441,111 @@ def __init__(self, kind) -> None:
),
]

test_data_get_job = [
(
"valid flow with default namespace and default timeout",
{"name": TEST_NAME},
SUCCESS,
),
(
"valid flow with all parameters set",
{
"name": TEST_NAME,
"namespace": TEST_NAME,
"job_kind": constants.PYTORCHJOB_KIND,
"timeout": 120,
},
SUCCESS,
),
(
"invalid flow with default namespace and a Job that doesn't exist",
{"name": TEST_NAME, "job_kind": constants.TFJOB_KIND},
RuntimeError,
),
(
"invalid flow incorrect parameter",
{"name": TEST_NAME, "test": "example"},
TypeError,
),
(
"invalid flow withincorrect value",
{"name": TEST_NAME, "job_kind": "FailJob"},
ValueError,
),
(
"runtime error case",
{
"name": TEST_NAME,
"namespace": "runtime",
"job_kind": constants.PYTORCHJOB_KIND,
},
RuntimeError,
),
(
"invalid flow with timeout error",
{"name": TEST_NAME, "namespace": TIMEOUT},
TimeoutError,
),
(
"invalid flow with runtime error",
{"name": TEST_NAME, "namespace": RUNTIME},
RuntimeError,
),
]


test_data_delete_job = [
(
"valid flow with default namespace",
{
"name": TEST_NAME,
},
SUCCESS,
),
(
"invalid extra parameter",
{"name": TEST_NAME, "namespace": TEST_NAME, "example": "test"},
TypeError,
),
(
"invalid job kind",
{"name": TEST_NAME, "job_kind": "invalid_job_kind"},
RuntimeError,
),
(
"job name missing",
{"namespace": TEST_NAME, "job_kind": constants.PYTORCHJOB_KIND},
TypeError,
),
(
"delete_namespaced_custom_object timeout error",
{"name": TEST_NAME, "namespace": TIMEOUT},
TimeoutError,
),
(
"delete_namespaced_custom_object runtime error",
{"name": TEST_NAME, "namespace": RUNTIME},
RuntimeError,
),
(
"valid flow",
{
"name": TEST_NAME,
"namespace": TEST_NAME,
"job_kind": constants.PYTORCHJOB_KIND,
},
SUCCESS,
),
(
"valid flow with delete options",
{
"name": TEST_NAME,
"delete_options": V1DeleteOptions(grace_period_seconds=30),
},
SUCCESS,
),
]


@pytest.fixture
def training_client():
Expand All @@ -428,6 +555,9 @@ def training_client():
create_namespaced_custom_object=Mock(side_effect=conditional_error_handler),
patch_namespaced_custom_object=Mock(side_effect=conditional_error_handler),
delete_namespaced_custom_object=Mock(side_effect=conditional_error_handler),
get_namespaced_custom_object=Mock(
side_effect=get_namespaced_custom_object_response
),
),
), patch(
"kubernetes.client.CoreV1Api",
Expand All @@ -436,8 +566,6 @@ def training_client():
),
), patch(
"kubernetes.config.load_kube_config", return_value=Mock()
), patch.object(
TrainingClient, "get_job", side_effect=get_job_response
):
client = TrainingClient(job_kind=constants.PYTORCHJOB_KIND)
yield client
Expand Down Expand Up @@ -536,69 +664,32 @@ def test_wait_for_job_conditions(training_client, test_name, kwargs, expected_ou
print("test execution complete")


test_data_delete_job = [
(
"valid flow with default namespace",
{
"name": TEST_NAME,
},
SUCCESS,
),
(
"invalid extra parameter",
{"name": TEST_NAME, "namespace": TEST_NAME, "example": "test"},
TypeError,
),
(
"invalid job kind",
{"name": TEST_NAME, "job_kind": "invalid_job_kind"},
RuntimeError,
),
(
"job name missing",
{"namespace": TEST_NAME, "job_kind": constants.PYTORCHJOB_KIND},
TypeError,
),
(
"delete_namespaced_custom_object timeout error",
{"name": TEST_NAME, "namespace": TIMEOUT},
TimeoutError,
),
(
"delete_namespaced_custom_object runtime error",
{"name": TEST_NAME, "namespace": RUNTIME},
RuntimeError,
),
(
"valid flow",
{
"name": TEST_NAME,
"namespace": TEST_NAME,
"job_kind": constants.PYTORCHJOB_KIND,
},
SUCCESS,
),
(
"valid flow with delete options",
{
"name": TEST_NAME,
"delete_options": V1DeleteOptions(grace_period_seconds=30),
},
SUCCESS,
),
]


@pytest.mark.parametrize("test_name,kwargs,expected_output", test_data_delete_job)
def test_delete_job(training_client, test_name, kwargs, expected_output):
"""
test delete_job function of training client
"""
print("Executing test: ", test_name)

try:
training_client.delete_job(**kwargs)
assert expected_output == SUCCESS
except Exception as e:
assert type(e) is expected_output

print("test execution complete")


@pytest.mark.parametrize("test_name,kwargs,expected_output", test_data_get_job)
def test_get_job(training_client, test_name, kwargs, expected_output):
"""
test get_job function of training client
"""
print("Executing test: ", test_name)

try:
training_client.get_job(**kwargs)
assert expected_output == SUCCESS
except Exception as e:
assert type(e) is expected_output

print("test execution complete")

0 comments on commit 00eef58

Please sign in to comment.