diff --git a/sdk/python/kubeflow/training/api/training_client_test.py b/sdk/python/kubeflow/training/api/training_client_test.py index 32cf403b64..04187ac20c 100644 --- a/sdk/python/kubeflow/training/api/training_client_test.py +++ b/sdk/python/kubeflow/training/api/training_client_test.py @@ -38,6 +38,7 @@ def get(self, timeout): if args[0] == "runtime": raise Exception() return Mock(items=LIST_RESPONSE) + return MockResponse() @@ -103,7 +104,7 @@ def __init__(self, kind) -> None: self.kind = kind -test_data_for_create_job = [ +test_data_create_job = [ ( "invalid extra parameter", {"job": create_job(), "namespace": TEST_NAME, "base_image": "test_image"}, @@ -193,32 +194,26 @@ def __init__(self, kind) -> None: ), ] -test_data_for_get_job_pods = [ +test_data_get_job_pods = [ ( "valid flow with default namespace and default timeout", { "name": TEST_NAME, }, f"{constants.JOB_NAME_LABEL}={TEST_NAME}", - LIST_RESPONSE + LIST_RESPONSE, ), ( "invalid replica_type", - { - "name": TEST_NAME, - "replica_type": "invalid_replica_type" - }, + {"name": TEST_NAME, "replica_type": "invalid_replica_type"}, "Label not relevant", - ValueError + ValueError, ), ( "invalid replica_type (uppercase)", - { - "name": TEST_NAME, - "replica_type": constants.REPLICA_TYPE_WORKER - }, + {"name": TEST_NAME, "replica_type": constants.REPLICA_TYPE_WORKER}, "Label not relevant", - ValueError + ValueError, ), ( "valid flow with specific timeout, replica_index, replica_type and master role", @@ -226,13 +221,13 @@ def __init__(self, kind) -> None: "name": TEST_NAME, "namespace": "test_namespace", "timeout": 60, - "is_master": True, + "is_master": True, "replica_type": constants.REPLICA_TYPE_MASTER.lower(), - "replica_index": 0 + "replica_index": 0, }, f"{constants.JOB_NAME_LABEL}={TEST_NAME},{constants.JOB_ROLE_LABEL}={constants.JOB_ROLE_MASTER}" f",{constants.REPLICA_TYPE_LABEL}={constants.REPLICA_TYPE_MASTER.lower()},{constants.REPLICA_INDEX_LABEL}=0", - LIST_RESPONSE + LIST_RESPONSE, ), ( "invalid flow with TimeoutError", @@ -241,7 +236,7 @@ def __init__(self, kind) -> None: "namespace": "timeout", }, "Label not relevant", - TimeoutError + TimeoutError, ), ( "invalid flow with RuntimeError", @@ -250,34 +245,33 @@ def __init__(self, kind) -> None: "namespace": "runtime", }, "Label not relevant", - RuntimeError - ) + RuntimeError, + ), ] @pytest.fixture def training_client(): with patch( - "kubernetes.client.CustomObjectsApi", - return_value=Mock( - create_namespaced_custom_object=Mock( - side_effect=create_namespaced_custom_object_response - ) - ), - ), patch("kubernetes.client.CoreV1Api", - return_value=Mock( - list_namespaced_pod=Mock( - side_effect=list_namespaced_pod_response - ) - ) - ), patch("kubernetes.config.load_kube_config", - return_value=Mock() - ): + "kubernetes.client.CustomObjectsApi", + return_value=Mock( + create_namespaced_custom_object=Mock( + side_effect=create_namespaced_custom_object_response + ) + ), + ), patch( + "kubernetes.client.CoreV1Api", + return_value=Mock( + list_namespaced_pod=Mock(side_effect=list_namespaced_pod_response) + ), + ), patch( + "kubernetes.config.load_kube_config", return_value=Mock() + ): client = TrainingClient(job_kind=constants.PYTORCHJOB_KIND) yield client -@pytest.mark.parametrize("test_name,kwargs,expected_output", test_data_for_create_job) +@pytest.mark.parametrize("test_name,kwargs,expected_output", test_data_create_job) def test_create_job(training_client, test_name, kwargs, expected_output): """ test create_job function of training client @@ -291,8 +285,13 @@ def test_create_job(training_client, test_name, kwargs, expected_output): print("test execution complete") -@pytest.mark.parametrize("test_name,kwargs,expected_label_selector,expected_output", test_data_for_get_job_pods,) -def test_get_job_pods(training_client, test_name, kwargs, expected_label_selector, expected_output): +@pytest.mark.parametrize( + "test_name,kwargs,expected_label_selector,expected_output", + test_data_get_job_pods, +) +def test_get_job_pods( + training_client, test_name, kwargs, expected_label_selector, expected_output +): """ test get_job_pods function of training client """ @@ -303,9 +302,10 @@ def test_get_job_pods(training_client, test_name, kwargs, expected_label_selecto training_client.core_api.list_namespaced_pod.assert_called_with( kwargs.get("namespace", constants.DEFAULT_NAMESPACE), label_selector=expected_label_selector, - async_req=True) + async_req=True, + ) assert out[0].pop("timeout") == kwargs.get("timeout", constants.DEFAULT_TIMEOUT) assert out == expected_output except Exception as e: assert type(e) is expected_output - print("test execution complete") \ No newline at end of file + print("test execution complete")