diff --git a/sdk/python/kubeflow/training/api/training_client_test.py b/sdk/python/kubeflow/training/api/training_client_test.py index f844a868c2..cfea20ad98 100644 --- a/sdk/python/kubeflow/training/api/training_client_test.py +++ b/sdk/python/kubeflow/training/api/training_client_test.py @@ -64,6 +64,25 @@ def get_namespaced_custom_object_response(*args, **kwargs): return mock_thread +def list_namespaced_custom_object_response(*args, **kwargs): + if args[2] == "timeout": + raise multiprocessing.TimeoutError() + elif args[2] == "runtime": + raise RuntimeError() + + # Create a serialized Job + serialized_job = serialize_k8s_object(generate_job_with_status(create_job())) + + # Mock the response containing a list of jobs + mock_response = {"items": [serialized_job]} + + # Mock the thread and set it's return value to the mock response + mock_thread = Mock() + mock_thread.get.return_value = mock_response + + return mock_thread + + def list_namespaced_pod_response(*args, **kwargs): class MockResponse: def get(self, timeout): @@ -493,6 +512,56 @@ def __init__(self, kind) -> None: ), ] +test_data_list_jobs = [ + ( + "valid flow with default namespace and default timeout", + {}, + SUCCESS, + ), + ( + "valid flow with all parameters set", + { + "namespace": TEST_NAME, + "job_kind": constants.PYTORCHJOB_KIND, + "timeout": 120, + }, + SUCCESS, + ), + ( + "invalid flow with default namespace and a Job that doesn't exist", + {"job_kind": constants.TFJOB_KIND}, + RuntimeError, + ), + ( + "invalid flow with incorrect parameter", + {"test": "example"}, + TypeError, + ), + ( + "invalid flow with incorrect job_kind value", + {"job_kind": "FailJob"}, + ValueError, + ), + ( + "runtime error case", + { + "namespace": RUNTIME, + "job_kind": constants.PYTORCHJOB_KIND, + }, + RuntimeError, + ), + ( + "invalid flow with timeout error", + {"namespace": TIMEOUT}, + TimeoutError, + ), + ( + "invalid flow with runtime error", + {"namespace": RUNTIME}, + RuntimeError, + ), +] + test_data_delete_job = [ ( @@ -558,6 +627,9 @@ def training_client(): get_namespaced_custom_object=Mock( side_effect=get_namespaced_custom_object_response ), + list_namespaced_custom_object=Mock( + side_effect=list_namespaced_custom_object_response + ), ), ), patch( "kubernetes.client.CoreV1Api", @@ -693,3 +765,19 @@ def test_get_job(training_client, test_name, kwargs, expected_output): assert type(e) is expected_output print("test execution complete") + + +@pytest.mark.parametrize("test_name,kwargs,expected_output", test_data_list_jobs) +def test_list_jobs(training_client, test_name, kwargs, expected_output): + """ + test list_jobs function of training client + """ + print("Executing test: ", test_name) + + try: + training_client.list_jobs(**kwargs) + assert expected_output == SUCCESS + except Exception as e: + assert type(e) is expected_output + + print("test execution complete")