Skip to content

Commit

Permalink
Fixed after review
Browse files Browse the repository at this point in the history
Signed-off-by: yelias <[email protected]>
  • Loading branch information
yelias committed Jul 21, 2024
1 parent bde097a commit 2e6661a
Showing 1 changed file with 39 additions and 39 deletions.
78 changes: 39 additions & 39 deletions sdk/python/kubeflow/training/api/training_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def get(self, timeout):
if args[0] == "runtime":
raise Exception()
return Mock(items=LIST_RESPONSE)

return MockResponse()


Expand Down Expand Up @@ -103,7 +104,7 @@ def __init__(self, kind) -> None:
self.kind = kind


test_data_for_create_job = [
test_data_create_job = [
(
"invalid extra parameter",
{"job": create_job(), "namespace": TEST_NAME, "base_image": "test_image"},
Expand Down Expand Up @@ -193,46 +194,40 @@ def __init__(self, kind) -> None:
),
]

test_data_for_get_job_pods = [
test_data_get_job_pods = [
(
"valid flow with default namespace and default timeout",
{
"name": TEST_NAME,
},
f"{constants.JOB_NAME_LABEL}={TEST_NAME}",
LIST_RESPONSE
LIST_RESPONSE,
),
(
"invalid replica_type",
{
"name": TEST_NAME,
"replica_type": "invalid_replica_type"
},
{"name": TEST_NAME, "replica_type": "invalid_replica_type"},
"Label not relevant",
ValueError
ValueError,
),
(
"invalid replica_type (uppercase)",
{
"name": TEST_NAME,
"replica_type": constants.REPLICA_TYPE_WORKER
},
{"name": TEST_NAME, "replica_type": constants.REPLICA_TYPE_WORKER},
"Label not relevant",
ValueError
ValueError,
),
(
"valid flow with specific timeout, replica_index, replica_type and master role",
{
"name": TEST_NAME,
"namespace": "test_namespace",
"timeout": 60,
"is_master": True,
"is_master": True,
"replica_type": constants.REPLICA_TYPE_MASTER.lower(),
"replica_index": 0
"replica_index": 0,
},
f"{constants.JOB_NAME_LABEL}={TEST_NAME},{constants.JOB_ROLE_LABEL}={constants.JOB_ROLE_MASTER}"
f",{constants.REPLICA_TYPE_LABEL}={constants.REPLICA_TYPE_MASTER.lower()},{constants.REPLICA_INDEX_LABEL}=0",
LIST_RESPONSE
LIST_RESPONSE,
),
(
"invalid flow with TimeoutError",
Expand All @@ -241,7 +236,7 @@ def __init__(self, kind) -> None:
"namespace": "timeout",
},
"Label not relevant",
TimeoutError
TimeoutError,
),
(
"invalid flow with RuntimeError",
Expand All @@ -250,34 +245,33 @@ def __init__(self, kind) -> None:
"namespace": "runtime",
},
"Label not relevant",
RuntimeError
)
RuntimeError,
),
]


@pytest.fixture
def training_client():
with patch(
"kubernetes.client.CustomObjectsApi",
return_value=Mock(
create_namespaced_custom_object=Mock(
side_effect=create_namespaced_custom_object_response
)
),
), patch("kubernetes.client.CoreV1Api",
return_value=Mock(
list_namespaced_pod=Mock(
side_effect=list_namespaced_pod_response
)
)
), patch("kubernetes.config.load_kube_config",
return_value=Mock()
):
"kubernetes.client.CustomObjectsApi",
return_value=Mock(
create_namespaced_custom_object=Mock(
side_effect=create_namespaced_custom_object_response
)
),
), patch(
"kubernetes.client.CoreV1Api",
return_value=Mock(
list_namespaced_pod=Mock(side_effect=list_namespaced_pod_response)
),
), patch(
"kubernetes.config.load_kube_config", return_value=Mock()
):
client = TrainingClient(job_kind=constants.PYTORCHJOB_KIND)
yield client


@pytest.mark.parametrize("test_name,kwargs,expected_output", test_data_for_create_job)
@pytest.mark.parametrize("test_name,kwargs,expected_output", test_data_create_job)
def test_create_job(training_client, test_name, kwargs, expected_output):
"""
test create_job function of training client
Expand All @@ -291,8 +285,13 @@ def test_create_job(training_client, test_name, kwargs, expected_output):
print("test execution complete")


@pytest.mark.parametrize("test_name,kwargs,expected_label_selector,expected_output", test_data_for_get_job_pods,)
def test_get_job_pods(training_client, test_name, kwargs, expected_label_selector, expected_output):
@pytest.mark.parametrize(
"test_name,kwargs,expected_label_selector,expected_output",
test_data_get_job_pods,
)
def test_get_job_pods(
training_client, test_name, kwargs, expected_label_selector, expected_output
):
"""
test get_job_pods function of training client
"""
Expand All @@ -303,9 +302,10 @@ def test_get_job_pods(training_client, test_name, kwargs, expected_label_selecto
training_client.core_api.list_namespaced_pod.assert_called_with(
kwargs.get("namespace", constants.DEFAULT_NAMESPACE),
label_selector=expected_label_selector,
async_req=True)
async_req=True,
)
assert out[0].pop("timeout") == kwargs.get("timeout", constants.DEFAULT_TIMEOUT)
assert out == expected_output
except Exception as e:
assert type(e) is expected_output
print("test execution complete")
print("test execution complete")

0 comments on commit 2e6661a

Please sign in to comment.