diff --git a/notebooks/scenarios/bigquery/000-start-and-configure-server-and-admins.ipynb b/notebooks/scenarios/bigquery/000-start-and-configure-server-and-admins.ipynb index 7320426de06..3aabe924a4c 100644 --- a/notebooks/scenarios/bigquery/000-start-and-configure-server-and-admins.ipynb +++ b/notebooks/scenarios/bigquery/000-start-and-configure-server-and-admins.ipynb @@ -211,7 +211,7 @@ "metadata": {}, "outputs": [], "source": [ - "create_checkpoint(root_client)" + "create_checkpoint(name=\"000-start-and-config\", client=root_client)" ] }, { @@ -231,13 +231,6 @@ "source": [ "server.land()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -256,7 +249,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.5" + "version": "3.11.10" } }, "nbformat": 4, diff --git a/notebooks/scenarios/bigquery/001-scale-delete-worker-pools.ipynb b/notebooks/scenarios/bigquery/001-scale-delete-worker-pools.ipynb index 69b16576fdf..65ad4ae6dde 100644 --- a/notebooks/scenarios/bigquery/001-scale-delete-worker-pools.ipynb +++ b/notebooks/scenarios/bigquery/001-scale-delete-worker-pools.ipynb @@ -93,7 +93,7 @@ "outputs": [], "source": [ "load_from_checkpoint(\n", - " prev_nb_filename=\"000-start-and-configure-server-and-admins\",\n", + " name=\"000-start-and-config\",\n", " client=server.client,\n", " root_email=ROOT_EMAIL,\n", " root_password=ROOT_PASSWORD,\n", @@ -357,14 +357,6 @@ "source": [ "server.land()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "29", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -383,7 +375,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.5" + "version": "3.11.10" } }, "nbformat": 4, diff --git a/packages/syft/src/syft/util/test_helpers/checkpoint.py b/packages/syft/src/syft/util/test_helpers/checkpoint.py index 17dece9d38c..99b0a3b1e52 100644 --- a/packages/syft/src/syft/util/test_helpers/checkpoint.py +++ b/packages/syft/src/syft/util/test_helpers/checkpoint.py @@ -1,19 +1,14 @@ # stdlib import datetime -import json import os from pathlib import Path -# third party -import ipykernel - # syft absolute from syft import SyftError from syft import SyftException from syft.client.client import SyftClient from syft.service.user.user_roles import ServiceRole from syft.util.util import get_root_data_path -from syft.util.util import is_interpreter_jupyter # relative from ...server.env import get_default_root_email @@ -24,53 +19,25 @@ CHECKPOINT_DIR_PREFIX = "chkpt" -def get_notebook_name_from_pytest_env() -> str | None: - """ - Returns the notebook file name from the test environment variable 'PYTEST_CURRENT_TEST'. - If not available, returns None. - """ - pytest_current_test = os.environ.get("PYTEST_CURRENT_TEST", "") - # Split by "::" and return the first part, which is the file path - return os.path.basename(pytest_current_test.split("::")[0]) - - -def current_nbname() -> Path: - """Retrieve the current Jupyter notebook name.""" - curr_kernel_file = Path(ipykernel.get_connection_file()) - kernel_file = json.loads(curr_kernel_file.read_text()) - nb_name = kernel_file.get("jupyter_session", "") - if not nb_name: - nb_name = get_notebook_name_from_pytest_env() - return Path(nb_name) - - def root_checkpoint_path() -> Path: return get_root_data_path() / CHECKPOINT_ROOT -def checkpoint_parent_dir(server_uid: str, nb_name: str | None = None) -> Path: - """Return the checkpoint directory for the current notebook and server.""" - if is_interpreter_jupyter: - nb_name = nb_name if nb_name else current_nbname().stem - return Path(f"{nb_name}/{server_uid}") if nb_name else Path(server_uid) - return Path(server_uid) - - -def get_checkpoints_dir(server_uid: str, nb_name: str) -> Path: - return root_checkpoint_path() / checkpoint_parent_dir(server_uid, nb_name) - +def get_checkpoint_parent_dir(server_uid: str, chkpt_name: str) -> Path: + return root_checkpoint_path() / chkpt_name / server_uid -def get_checkpoint_dir( - server_uid: str, checkpoint_dir: str, nb_name: str | None = None -) -> Path: - return get_checkpoints_dir(server_uid, nb_name) / checkpoint_dir - -def create_checkpoint_dir(server_uid: str) -> Path: - """Create a checkpoint directory for the current notebook and server.""" +def create_checkpoint_dir(server_uid: str, chkpt_name: str) -> Path: + """Create a checkpoint directory by chkpt_name and server_uid.""" timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") checkpoint_dir = f"{CHECKPOINT_DIR_PREFIX}_{timestamp}" - checkpoint_full_path = get_checkpoint_dir(server_uid, checkpoint_dir=checkpoint_dir) + checkpoint_parent_dir = get_checkpoint_parent_dir( + server_uid=server_uid, chkpt_name=chkpt_name + ) + checkpoint_full_path = checkpoint_parent_dir / checkpoint_dir + + # Format of Checkpoint Directory: + # /checkpoints/chkpt_name//chkpt_ checkpoint_full_path.mkdir(parents=True, exist_ok=True) return checkpoint_full_path @@ -81,6 +48,7 @@ def is_admin(client: SyftClient) -> bool: def create_checkpoint( + name: str, # Name of the checkpoint client: SyftClient, root_email: str | None = None, root_pwd: str | None = None, @@ -103,12 +71,9 @@ def create_checkpoint( if isinstance(migration_data, SyftError): raise SyftException(message=migration_data.message) - if not is_interpreter_jupyter(): - raise SyftException( - message="Checkpoint can only be created in Jupyter Notebook." - ) - - checkpoint_dir = create_checkpoint_dir(server_uid=client.id.to_string()) + checkpoint_dir = create_checkpoint_dir( + server_uid=client.id.to_string(), chkpt_name=name + ) migration_data.save( path=checkpoint_dir / "migration.blob", yaml_path=checkpoint_dir / "migration.yaml", @@ -116,18 +81,12 @@ def create_checkpoint( print(f"Checkpoint saved at: \n {checkpoint_dir}") -def last_checkpoint_path_for_nb(server_uid: str, nb_name: str = None) -> Path | None: - """Return the directory of the latest checkpoint for the given notebook.""" - nb_name = nb_name if nb_name else current_nbname().stem - checkpoint_dir = None - if len(nb_name.split("/")) > 1: - nb_name, checkpoint_dir = nb_name.split("/") +def last_checkpoint_path_for(server_uid: str, chkpt_name: str) -> Path | None: + """Return the directory of the latest checkpoint for the given name.""" - filename = nb_name.split(".ipynb")[0] - checkpoint_parent_dir = get_checkpoints_dir(server_uid, filename) - - if checkpoint_dir: - return checkpoint_parent_dir / checkpoint_dir + checkpoint_parent_dir = get_checkpoint_parent_dir( + server_uid=server_uid, chkpt_name=chkpt_name + ) checkpoint_dirs = [ d @@ -139,7 +98,7 @@ def last_checkpoint_path_for_nb(server_uid: str, nb_name: str = None) -> Path | ] if checkpoints_dirs_with_blob_entry: - print("Loading from the last checkpoint of the current notebook.") + print(f"Loading from the last checkpoint for: {chkpt_name}") return max(checkpoints_dirs_with_blob_entry, key=lambda d: d.stat().st_mtime) return None @@ -153,17 +112,13 @@ def get_registry_credentials() -> tuple[str, str]: def load_from_checkpoint( client: SyftClient, - prev_nb_filename: str | None = None, + name: str, root_email: str | None = None, root_password: str | None = None, registry_username: str | None = None, registry_password: str | None = None, - checkpoint_name: str | None = None, ) -> None: - """Load the last saved checkpoint for the given notebook state.""" - if prev_nb_filename is None: - print("Loading from the last checkpoint of the current notebook.") - prev_nb_filename = current_nbname().stem + """Load the last saved checkpoint for the given checkpoint state.""" root_email = "info@openmined.org" if root_email is None else root_email root_password = "changethis" if root_password is None else root_password @@ -173,12 +128,12 @@ def load_from_checkpoint( if is_admin(client) else client.login(email=root_email, password=root_password) ) - latest_checkpoint_dir = last_checkpoint_path_for_nb( - client.id.to_string(), prev_nb_filename + latest_checkpoint_dir = last_checkpoint_path_for( + server_uid=client.id.to_string(), chkpt_name=name ) if latest_checkpoint_dir is None: - print(f"No last checkpoint found for notebook: {prev_nb_filename}") + print(f"No last checkpoint found for : {name}") return print(f"Loading from checkpoint: {latest_checkpoint_dir}")