diff --git a/src/codeflare_sdk/cluster/cluster.py b/src/codeflare_sdk/cluster/cluster.py index 59023d8ef..7c94f9502 100644 --- a/src/codeflare_sdk/cluster/cluster.py +++ b/src/codeflare_sdk/cluster/cluster.py @@ -71,8 +71,10 @@ def __init__(self, config: ClusterConfiguration): """ self.config = config self.app_wrapper_yaml = self.create_app_wrapper() - self.app_wrapper_name = self.app_wrapper_yaml.split(".")[0] self._job_submission_client = None + self.app_wrapper_name = self.app_wrapper_yaml.replace(".yaml", "").split("/")[ + -1 + ] @property def _client_headers(self): diff --git a/src/codeflare_sdk/utils/generate_yaml.py b/src/codeflare_sdk/utils/generate_yaml.py index cf9686c43..a6aae3082 100755 --- a/src/codeflare_sdk/utils/generate_yaml.py +++ b/src/codeflare_sdk/utils/generate_yaml.py @@ -19,6 +19,7 @@ import yaml import sys +import os import argparse import uuid from kubernetes import client, config @@ -506,8 +507,14 @@ def disable_raycluster_tls(resources): def write_user_appwrapper(user_yaml, output_file_name): + # Create the directory if it doesn't exist + directory_path = os.path.dirname(output_file_name) + if not os.path.exists(directory_path): + os.makedirs(directory_path) + with open(output_file_name, "w") as outfile: yaml.dump(user_yaml, outfile, default_flow_style=False) + print(f"Written to: {output_file_name}") @@ -675,7 +682,8 @@ def generate_appwrapper( if openshift_oauth: enable_openshift_oauth(user_yaml, cluster_name, namespace) - outfile = appwrapper_name + ".yaml" + directory_path = os.path.expanduser("~/.codeflare/appwrapper/") + outfile = os.path.join(directory_path, appwrapper_name + ".yaml") if not mcad: write_components(user_yaml, outfile) else: diff --git a/tests/unit_test.py b/tests/unit_test.py index 620476df7..44bcddea0 100644 --- a/tests/unit_test.py +++ b/tests/unit_test.py @@ -24,6 +24,7 @@ from codeflare_sdk.cluster import cluster parent = Path(__file__).resolve().parents[1] +aw_dir = os.path.expanduser("~/.codeflare/appwrapper/") sys.path.append(str(parent) + "/src") from kubernetes import client, config @@ -261,10 +262,12 @@ def test_config_creation(): def test_cluster_creation(mocker): cluster = createClusterWithConfig(mocker) - assert cluster.app_wrapper_yaml == "unit-test-cluster.yaml" + assert cluster.app_wrapper_yaml == f"{aw_dir}unit-test-cluster.yaml" assert cluster.app_wrapper_name == "unit-test-cluster" assert filecmp.cmp( - "unit-test-cluster.yaml", f"{parent}/tests/test-case.yaml", shallow=True + f"{aw_dir}unit-test-cluster.yaml", + f"{parent}/tests/test-case.yaml", + shallow=True, ) @@ -290,10 +293,10 @@ def test_cluster_creation_no_mcad(mocker): config.name = "unit-test-cluster-ray" config.mcad = False cluster = Cluster(config) - assert cluster.app_wrapper_yaml == "unit-test-cluster-ray.yaml" + assert cluster.app_wrapper_yaml == f"{aw_dir}unit-test-cluster-ray.yaml" assert cluster.app_wrapper_name == "unit-test-cluster-ray" assert filecmp.cmp( - "unit-test-cluster-ray.yaml", + f"{aw_dir}unit-test-cluster-ray.yaml", f"{parent}/tests/test-case-no-mcad.yamls", shallow=True, ) @@ -313,10 +316,12 @@ def test_cluster_creation_priority(mocker): return_value={"spec": {"domain": "apps.cluster.awsroute.org"}}, ) cluster = Cluster(config) - assert cluster.app_wrapper_yaml == "prio-test-cluster.yaml" + assert cluster.app_wrapper_yaml == f"{aw_dir}prio-test-cluster.yaml" assert cluster.app_wrapper_name == "prio-test-cluster" assert filecmp.cmp( - "prio-test-cluster.yaml", f"{parent}/tests/test-case-prio.yaml", shallow=True + f"{aw_dir}prio-test-cluster.yaml", + f"{parent}/tests/test-case-prio.yaml", + shallow=True, ) @@ -335,7 +340,7 @@ def test_default_cluster_creation(mocker): ) cluster = Cluster(default_config) - assert cluster.app_wrapper_yaml == "unit-test-default-cluster.yaml" + assert cluster.app_wrapper_yaml == f"{aw_dir}unit-test-default-cluster.yaml" assert cluster.app_wrapper_name == "unit-test-default-cluster" assert cluster.config.namespace == "opendatahub" @@ -365,13 +370,13 @@ def arg_check_apply_effect(group, version, namespace, plural, body, *args): if plural == "appwrappers": assert group == "workload.codeflare.dev" assert version == "v1beta1" - with open("unit-test-cluster.yaml") as f: + with open(f"{aw_dir}unit-test-cluster.yaml") as f: aw = yaml.load(f, Loader=yaml.FullLoader) assert body == aw elif plural == "rayclusters": assert group == "ray.io" assert version == "v1alpha1" - with open("unit-test-cluster-ray.yaml") as f: + with open(f"{aw_dir}unit-test-cluster-ray.yaml") as f: yamls = yaml.load_all(f, Loader=yaml.FullLoader) for resource in yamls: if resource["kind"] == "RayCluster": @@ -379,7 +384,7 @@ def arg_check_apply_effect(group, version, namespace, plural, body, *args): elif plural == "routes": assert group == "route.openshift.io" assert version == "v1" - with open("unit-test-cluster-ray.yaml") as f: + with open(f"{aw_dir}unit-test-cluster-ray.yaml") as f: yamls = yaml.load_all(f, Loader=yaml.FullLoader) for resource in yamls: if resource["kind"] == "Route": @@ -2408,7 +2413,7 @@ def parse_j(cmd): def test_AWManager_creation(): - testaw = AWManager("test.yaml") + testaw = AWManager(f"{aw_dir}test.yaml") assert testaw.name == "test" assert testaw.namespace == "ns" assert testaw.submitted == False @@ -2432,7 +2437,7 @@ def arg_check_aw_apply_effect(group, version, namespace, plural, body, *args): assert version == "v1beta1" assert namespace == "ns" assert plural == "appwrappers" - with open("test.yaml") as f: + with open(f"{aw_dir}test.yaml") as f: aw = yaml.load(f, Loader=yaml.FullLoader) assert body == aw assert args == tuple() @@ -2448,7 +2453,7 @@ def arg_check_aw_del_effect(group, version, namespace, plural, name, *args): def test_AWManager_submit_remove(mocker, capsys): - testaw = AWManager("test.yaml") + testaw = AWManager(f"{aw_dir}test.yaml") testaw.remove() captured = capsys.readouterr() assert ( @@ -2876,13 +2881,12 @@ def test_gen_app_wrapper_with_oauth(mocker: MockerFixture): # Make sure to always keep this function last def test_cleanup(): - os.remove("unit-test-cluster.yaml") - os.remove("prio-test-cluster.yaml") - os.remove("unit-test-default-cluster.yaml") - os.remove("unit-test-cluster-ray.yaml") - os.remove("test.yaml") - os.remove("raytest2.yaml") - os.remove("quicktest.yaml") + os.remove(f"{aw_dir}unit-test-cluster.yaml") + os.remove(f"{aw_dir}prio-test-cluster.yaml") + os.remove(f"{aw_dir}unit-test-default-cluster.yaml") + os.remove(f"{aw_dir}test.yaml") + os.remove(f"{aw_dir}raytest2.yaml") + os.remove(f"{aw_dir}quicktest.yaml") os.remove("tls-cluster-namespace/ca.crt") os.remove("tls-cluster-namespace/tls.crt") os.remove("tls-cluster-namespace/tls.key")