Skip to content

Commit

Permalink
Merge pull request #9331 from OpenMined/rasswanth/remove-name-input
Browse files Browse the repository at this point in the history
[WIP] Notebooks State - Misc Improvements
  • Loading branch information
rasswanth-s authored Oct 1, 2024
2 parents bc23abc + f7a2e20 commit 1c260e8
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 90 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@
"metadata": {},
"outputs": [],
"source": [
"create_checkpoint(root_client)"
"create_checkpoint(name=\"000-start-and-config\", client=root_client)"
]
},
{
Expand All @@ -231,13 +231,6 @@
"source": [
"server.land()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand All @@ -256,7 +249,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.11.10"
}
},
"nbformat": 4,
Expand Down
12 changes: 2 additions & 10 deletions notebooks/scenarios/bigquery/001-scale-delete-worker-pools.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
"outputs": [],
"source": [
"load_from_checkpoint(\n",
" prev_nb_filename=\"000-start-and-configure-server-and-admins\",\n",
" name=\"000-start-and-config\",\n",
" client=server.client,\n",
" root_email=ROOT_EMAIL,\n",
" root_password=ROOT_PASSWORD,\n",
Expand Down Expand Up @@ -357,14 +357,6 @@
"source": [
"server.land()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "29",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand All @@ -383,7 +375,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.11.10"
}
},
"nbformat": 4,
Expand Down
97 changes: 26 additions & 71 deletions packages/syft/src/syft/util/test_helpers/checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
# stdlib
import datetime
import json
import os
from pathlib import Path

# third party
import ipykernel

# syft absolute
from syft import SyftError
from syft import SyftException
from syft.client.client import SyftClient
from syft.service.user.user_roles import ServiceRole
from syft.util.util import get_root_data_path
from syft.util.util import is_interpreter_jupyter

# relative
from ...server.env import get_default_root_email
Expand All @@ -24,53 +19,25 @@
CHECKPOINT_DIR_PREFIX = "chkpt"


def get_notebook_name_from_pytest_env() -> str | None:
"""
Returns the notebook file name from the test environment variable 'PYTEST_CURRENT_TEST'.
If not available, returns None.
"""
pytest_current_test = os.environ.get("PYTEST_CURRENT_TEST", "")
# Split by "::" and return the first part, which is the file path
return os.path.basename(pytest_current_test.split("::")[0])


def current_nbname() -> Path:
"""Retrieve the current Jupyter notebook name."""
curr_kernel_file = Path(ipykernel.get_connection_file())
kernel_file = json.loads(curr_kernel_file.read_text())
nb_name = kernel_file.get("jupyter_session", "")
if not nb_name:
nb_name = get_notebook_name_from_pytest_env()
return Path(nb_name)


def root_checkpoint_path() -> Path:
return get_root_data_path() / CHECKPOINT_ROOT


def checkpoint_parent_dir(server_uid: str, nb_name: str | None = None) -> Path:
"""Return the checkpoint directory for the current notebook and server."""
if is_interpreter_jupyter:
nb_name = nb_name if nb_name else current_nbname().stem
return Path(f"{nb_name}/{server_uid}") if nb_name else Path(server_uid)
return Path(server_uid)


def get_checkpoints_dir(server_uid: str, nb_name: str) -> Path:
return root_checkpoint_path() / checkpoint_parent_dir(server_uid, nb_name)

def get_checkpoint_parent_dir(server_uid: str, chkpt_name: str) -> Path:
return root_checkpoint_path() / chkpt_name / server_uid

def get_checkpoint_dir(
server_uid: str, checkpoint_dir: str, nb_name: str | None = None
) -> Path:
return get_checkpoints_dir(server_uid, nb_name) / checkpoint_dir


def create_checkpoint_dir(server_uid: str) -> Path:
"""Create a checkpoint directory for the current notebook and server."""
def create_checkpoint_dir(server_uid: str, chkpt_name: str) -> Path:
"""Create a checkpoint directory by chkpt_name and server_uid."""
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
checkpoint_dir = f"{CHECKPOINT_DIR_PREFIX}_{timestamp}"
checkpoint_full_path = get_checkpoint_dir(server_uid, checkpoint_dir=checkpoint_dir)
checkpoint_parent_dir = get_checkpoint_parent_dir(
server_uid=server_uid, chkpt_name=chkpt_name
)
checkpoint_full_path = checkpoint_parent_dir / checkpoint_dir

# Format of Checkpoint Directory:
# <root_syft_dir>/checkpoints/chkpt_name/<server_uid>/chkpt_<timestamp>

checkpoint_full_path.mkdir(parents=True, exist_ok=True)
return checkpoint_full_path
Expand All @@ -81,6 +48,7 @@ def is_admin(client: SyftClient) -> bool:


def create_checkpoint(
name: str, # Name of the checkpoint
client: SyftClient,
root_email: str | None = None,
root_pwd: str | None = None,
Expand All @@ -103,31 +71,22 @@ def create_checkpoint(
if isinstance(migration_data, SyftError):
raise SyftException(message=migration_data.message)

if not is_interpreter_jupyter():
raise SyftException(
message="Checkpoint can only be created in Jupyter Notebook."
)

checkpoint_dir = create_checkpoint_dir(server_uid=client.id.to_string())
checkpoint_dir = create_checkpoint_dir(
server_uid=client.id.to_string(), chkpt_name=name
)
migration_data.save(
path=checkpoint_dir / "migration.blob",
yaml_path=checkpoint_dir / "migration.yaml",
)
print(f"Checkpoint saved at: \n {checkpoint_dir}")


def last_checkpoint_path_for_nb(server_uid: str, nb_name: str = None) -> Path | None:
"""Return the directory of the latest checkpoint for the given notebook."""
nb_name = nb_name if nb_name else current_nbname().stem
checkpoint_dir = None
if len(nb_name.split("/")) > 1:
nb_name, checkpoint_dir = nb_name.split("/")
def last_checkpoint_path_for(server_uid: str, chkpt_name: str) -> Path | None:
"""Return the directory of the latest checkpoint for the given name."""

filename = nb_name.split(".ipynb")[0]
checkpoint_parent_dir = get_checkpoints_dir(server_uid, filename)

if checkpoint_dir:
return checkpoint_parent_dir / checkpoint_dir
checkpoint_parent_dir = get_checkpoint_parent_dir(
server_uid=server_uid, chkpt_name=chkpt_name
)

checkpoint_dirs = [
d
Expand All @@ -139,7 +98,7 @@ def last_checkpoint_path_for_nb(server_uid: str, nb_name: str = None) -> Path |
]

if checkpoints_dirs_with_blob_entry:
print("Loading from the last checkpoint of the current notebook.")
print(f"Loading from the last checkpoint for: {chkpt_name}")
return max(checkpoints_dirs_with_blob_entry, key=lambda d: d.stat().st_mtime)

return None
Expand All @@ -153,17 +112,13 @@ def get_registry_credentials() -> tuple[str, str]:

def load_from_checkpoint(
client: SyftClient,
prev_nb_filename: str | None = None,
name: str,
root_email: str | None = None,
root_password: str | None = None,
registry_username: str | None = None,
registry_password: str | None = None,
checkpoint_name: str | None = None,
) -> None:
"""Load the last saved checkpoint for the given notebook state."""
if prev_nb_filename is None:
print("Loading from the last checkpoint of the current notebook.")
prev_nb_filename = current_nbname().stem
"""Load the last saved checkpoint for the given checkpoint state."""

root_email = "[email protected]" if root_email is None else root_email
root_password = "changethis" if root_password is None else root_password
Expand All @@ -173,12 +128,12 @@ def load_from_checkpoint(
if is_admin(client)
else client.login(email=root_email, password=root_password)
)
latest_checkpoint_dir = last_checkpoint_path_for_nb(
client.id.to_string(), prev_nb_filename
latest_checkpoint_dir = last_checkpoint_path_for(
server_uid=client.id.to_string(), chkpt_name=name
)

if latest_checkpoint_dir is None:
print(f"No last checkpoint found for notebook: {prev_nb_filename}")
print(f"No last checkpoint found for : {name}")
return

print(f"Loading from checkpoint: {latest_checkpoint_dir}")
Expand Down

0 comments on commit 1c260e8

Please sign in to comment.