From dfc0682e2214515cd7e3ef55d2aee24aa16edd11 Mon Sep 17 00:00:00 2001 From: Koen van der Veen Date: Thu, 12 Sep 2024 12:11:53 +0200 Subject: [PATCH 01/15] refactor user code status --- packages/syft/src/syft/__init__.py | 2 +- .../src/syft/service/code/status_service.py | 22 ++ .../syft/src/syft/service/code/user_code.py | 239 ++++++++---------- .../syft/service/code/user_code_service.py | 20 +- .../syft/src/syft/service/request/request.py | 10 +- packages/syft/src/syft/store/linked_obj.py | 19 +- .../service/sync/sync_resolve_single_test.py | 2 +- 7 files changed, 165 insertions(+), 149 deletions(-) diff --git a/packages/syft/src/syft/__init__.py b/packages/syft/src/syft/__init__.py index 62655f0aed1..95cc582e83d 100644 --- a/packages/syft/src/syft/__init__.py +++ b/packages/syft/src/syft/__init__.py @@ -53,7 +53,7 @@ from .service.api.api import api_endpoint from .service.api.api import api_endpoint_method from .service.api.api import create_new_api_endpoint as TwinAPIEndpoint -from .service.code.user_code import UserCodeStatus +from .service.code.user_code import UsercodeStatus from .service.code.user_code import syft_function from .service.code.user_code import syft_function_single_use from .service.data_subject import DataSubjectCreate as DataSubject diff --git a/packages/syft/src/syft/service/code/status_service.py b/packages/syft/src/syft/service/code/status_service.py index 1ffb70ebb6f..a30409915ca 100644 --- a/packages/syft/src/syft/service/code/status_service.py +++ b/packages/syft/src/syft/service/code/status_service.py @@ -12,6 +12,8 @@ from ...store.document_store import UIDPartitionKey from ...store.document_store_errors import StashException from ...types.result import as_result +from ...types.syft_object import PartialSyftObject +from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.uid import UID from ..context import AuthedServiceContext from ..response import SyftSuccess @@ -45,6 +47,13 @@ def get_by_uid( return self.query_one(credentials=credentials, qks=qks).unwrap() +class CodeStatusUpdate(PartialSyftObject): + __canonical_name__ = "CodeStatusUpdate" + __version__ = SYFT_OBJECT_VERSION_1 + + id: UID + + @serializable(canonical_name="UserCodeStatusService", version=1) class UserCodeStatusService(AbstractService): store: DocumentStore @@ -65,6 +74,19 @@ def create( obj=status, ).unwrap() + @service_method( + path="code_status.update", + name="update", + roles=ADMIN_ROLE_LEVEL, + autosplat=["code_update"], + unwrap_on_success=False, + ) + def update( + self, context: AuthedServiceContext, code_update: CodeStatusUpdate + ) -> SyftSuccess: + res = self.status.update(context.credentials, code_update).unwrap() + return SyftSuccess(message="UserCode updated successfully", value=res) + @service_method( path="code_status.get_by_uid", name="get_by_uid", roles=GUEST_ROLE_LEVEL ) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 0921ea9e704..725eb6ac0da 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -132,15 +132,102 @@ class UserCodeStatus(Enum): def __hash__(self) -> int: return hash(self.value) +@serializable(canonical_name="UserCodeStatusDecision", version=1) +class ApprovalDecision(SyftObject): + + status: UserCodeStatus + reason: str | None = None + + @property + def non_empty_reason(self): + if self.reason == "": + return None + return self.reason + + @serializable() class UserCodeStatusCollection(SyncableSyftObject): + """Currently this is a class that implements a mixed bag of two statusses + The first status is for a level 0 Request, which does not have an explicit + status_dict. However, when we call .denied, or .approved, we compute the + status on the fly. + The second use case is for a level 2 Request, in this case we store the status + dict on the object and use it as is. + """ __canonical_name__ = "UserCodeStatusCollection" __version__ = SYFT_OBJECT_VERSION_1 __repr_attrs__ = ["approved", "status_dict"] - status_dict: dict[ServerIdentity, tuple[UserCodeStatus, str]] = {} + + # this is empty in the case of l0 + status_dict: dict[ServerIdentity, ApprovalDecision] = {} + user_code_link: LinkedObject + user_verify_key: SyftVerifyKey + + _was_requested_on_lowside: bool = False + + # ugly and buggy optimization, remove at some point + _has_output_read_permissions_cache: bool | None = None + + @property + def approved(self) -> bool: + # only use this on the client side, in this case we can use self.get_api instead + # of using the context + self.get_approved(None) + + def get_approved(self, context: AuthedServiceContext | None): + if self._was_requested_on_lowside: + return self._compute_status_l0(context) == UserCodeStatus.APPROVED + else: + return all(x == UserCodeStatus.APPROVED for x, _ in self.status_dict.values()) + + @property + def denied(self) -> bool: + # for denied we use the status dict both for level 0 and level 2 + return any([approval_dec.status == UserCodeStatus.DENIED for approval_dec in self.status_dict.values()]) + + def _compute_status_l0( + self, context: AuthedServiceContext | None = None + ) -> UserCodeStatus: + has_readable_outputs = self.has_readable_outputs(context) + + if self.denied: + if has_readable_outputs: + prompt_warning_message( + "This request already has results published to the data scientist. " + "They will still be able to access those results." + ) + return UserCodeStatus.DENIED + elif has_readable_outputs: + return UserCodeStatus.APPROVED + else: + return UserCodeStatus.PENDING + + def has_readable_outputs(self, context: AuthedServiceContext | None = None): + if context is None: + # Clientside + api = self._get_api() + if self._has_output_read_permissions_cache is None: + has_readable_outputs = api.output.has_output_read_permissions( + self.user_code_link.object_uid, self.user_verify_key + ) + self._has_output_read_permissions_cache = has_readable_outputs + return has_readable_outputs + else: + return self._has_output_read_permissions_cache + else: + # Serverside + return context.server.services.output.has_output_read_permissions( + context, self.user_code_link.object_uid, self.user_verify_key + ) + + @property + def denial_reason_l0(self): + denial_reasons = [x.non_empty_reason for x in self.status_dict.values() + if x.status == UserCodeStatus.DENIED and x.non_empty_reason is not None] + return next(iter(denial_reasons), "") def syft_get_diffs(self, ext_obj: Any) -> list[AttrDiff]: # relative @@ -184,16 +271,20 @@ def _repr_html_(self) -> str: def __repr_syft_nested__(self) -> str: string = "" + status_dict = for server_identity, (status, reason) in self.status_dict.items(): string += f"{server_identity.server_name}: {status}, {reason}
" return string - def get_status_message(self) -> str: - if self.approved: + def get_status_message(self, context: AuthedServiceContext) -> str: + if self.get_approved(context): return f"{type(self)} approved" denial_string = "" string = "" - for server_identity, (status, reason) in self.status_dict.items(): + + status_dict = self._refresh_cache(context) + + for server_identity, (status, reason) in status_dict.items(): denial_string += f"Code status on server '{server_identity.server_name}' is '{status}'. Reason: {reason}" if not reason.endswith("."): denial_string += "." @@ -205,46 +296,6 @@ def get_status_message(self) -> str: else: return f"{type(self)} Your code is waiting for approval. {string}" - @property - def approved(self) -> bool: - return all(x == UserCodeStatus.APPROVED for x, _ in self.status_dict.values()) - - @property - def denied(self) -> bool: - for status, _ in self.status_dict.values(): - if status == UserCodeStatus.DENIED: - return True - return False - - def for_user_context(self, context: AuthedServiceContext) -> UserCodeStatus: - if context.server.server_type == ServerType.ENCLAVE: - keys = {status for status, _ in self.status_dict.values()} - if len(keys) == 1 and UserCodeStatus.APPROVED in keys: - return UserCodeStatus.APPROVED - elif UserCodeStatus.PENDING in keys and UserCodeStatus.DENIED not in keys: - return UserCodeStatus.PENDING - elif UserCodeStatus.DENIED in keys: - return UserCodeStatus.DENIED - else: - raise Exception(f"Invalid types in {keys} for Code Submission") - - elif context.server.server_type == ServerType.DATASITE: - server_identity = ServerIdentity( - server_name=context.server.name, - server_id=context.server.id, - verify_key=context.server.signing_key.verify_key, - ) - if server_identity in self.status_dict: - return self.status_dict[server_identity][0] - else: - raise Exception( - f"Code Object does not contain {context.server.name} Datasite's data" - ) - else: - raise Exception( - f"Invalid Server Type for Code Submission:{context.server.server_type}" - ) - @as_result(SyftException) def mutate( self, @@ -301,8 +352,7 @@ class UserCode(SyncableSyftObject): nested_codes: dict[str, tuple[LinkedObject, dict]] | None = {} worker_pool_name: str | None = None origin_server_side_type: ServerSideType - l0_deny_reason: str | None = None - _has_output_read_permissions_cache: bool | None = None + # l0_deny_reason: str | None = None __table_coll_widths__ = [ "min-content", @@ -395,86 +445,14 @@ def user(self) -> UserView: api = self.get_api() return api.services.user.get_by_verify_key(self.user_verify_key) - def _compute_status_l0( - self, context: AuthedServiceContext | None = None - ) -> UserCodeStatusCollection: - if context is None: - # Clientside - api = self._get_api() - server_identity = ServerIdentity.from_api(api) - - if self._has_output_read_permissions_cache is None: - is_approved = api.output.has_output_read_permissions( - self.id, self.user_verify_key - ) - self._has_output_read_permissions_cache = is_approved - else: - is_approved = self._has_output_read_permissions_cache - else: - # Serverside - server_identity = ServerIdentity.from_server(context.server) - is_approved = context.server.services.output.has_output_read_permissions( - context, self.id, self.user_verify_key - ) - is_denied = self.l0_deny_reason is not None - - if is_denied: - if is_approved: - prompt_warning_message( - "This request already has results published to the data scientist. " - "They will still be able to access those results." - ) - message = self.l0_deny_reason - status = (UserCodeStatus.DENIED, message) - elif is_approved: - status = (UserCodeStatus.APPROVED, "") - else: - status = (UserCodeStatus.PENDING, "") - status_dict = {server_identity: status} - - return UserCodeStatusCollection( - status_dict=status_dict, - user_code_link=LinkedObject.from_obj(self), - ) - @property def status(self) -> UserCodeStatusCollection: - # Clientside only - - if self.is_l0_deployment: - if self.status_link is not None: - raise SyftException( - public_message="Encountered a low side UserCode object with a status_link." - ) - return self._compute_status_l0() - - if self.status_link is None: - raise SyftException( - public_message="This UserCode does not have a status. Please contact the Admin." - ) - res = self.status_link.resolve - return res + # only use this client side + return self.get_status(None) @as_result(SyftException) - def get_status(self, context: AuthedServiceContext) -> UserCodeStatusCollection: - if self.is_l0_deployment: - if self.status_link is not None: - raise SyftException( - public_message="Encountered a low side UserCode object with a status_link." - ) - return self._compute_status_l0(context) - - if self.status_link is None: - raise SyftException( - public_message="This UserCode does not have a status. Please contact the Admin." - ) - - return self.status_link.resolve_with_context(context).unwrap() - - @as_result(SyftException) - def is_status_approved(self, context: AuthedServiceContext) -> bool: - status = self.get_status(context).unwrap() - return status.approved + def get_status(self, context: AuthedServiceContext | None) -> UserCodeStatusCollection: + return self.status_link.resolve_dynamic(context, load_cached=True) @property def input_owners(self) -> list[str] | None: @@ -524,7 +502,7 @@ def input_policy(self) -> InputPolicy | None: def get_input_policy(self, context: AuthedServiceContext) -> InputPolicy | None: status = self.get_status(context).unwrap() - if status.approved or self.input_policy_type.has_safe_serde: + if status.get_approved(context) or self.input_policy_type.has_safe_serde: return self._get_input_policy() return None @@ -586,7 +564,7 @@ def input_policy(self, value: Any) -> None: # type: ignore def get_output_policy(self, context: AuthedServiceContext) -> OutputPolicy | None: status = self.get_status(context).unwrap() - if status.approved or self.output_policy_type.has_safe_serde: + if status.get_approved(context) or self.output_policy_type.has_safe_serde: return self._get_output_policy() return None @@ -1539,11 +1517,12 @@ def create_code_status(context: TransformContext) -> TransformContext: if context.output is None: return context - # Low side requests have a computed status - if context.server.server_side_type == ServerSideType.LOW_SIDE: - return context + # # Low side requests have a computed status + # if + # return context + + was_requested_on_lowside = context.server.server_side_type == ServerSideType.LOW_SIDE - input_keys = list(context.output["input_policy_init_kwargs"].keys()) code_link = LinkedObject.from_uid( context.output["id"], UserCode, @@ -1559,13 +1538,17 @@ def create_code_status(context: TransformContext) -> TransformContext: status = UserCodeStatusCollection( status_dict={server_identity: (UserCodeStatus.PENDING, "")}, user_code_link=code_link, + user_verify_key=context.credentials, + _was_requested_on_lowside=was_requested_on_lowside ) elif context.server.server_type == ServerType.ENCLAVE: + input_keys = list(context.output["input_policy_init_kwargs"].keys()) status_dict = {key: (UserCodeStatus.PENDING, "") for key in input_keys} status = UserCodeStatusCollection( status_dict=status_dict, user_code_link=code_link, + user_verify_key=context.credentials, ) else: raise NotImplementedError( diff --git a/packages/syft/src/syft/service/code/user_code_service.py b/packages/syft/src/syft/service/code/user_code_service.py index 32863990a6a..1fc3ef8cc49 100644 --- a/packages/syft/src/syft/service/code/user_code_service.py +++ b/packages/syft/src/syft/service/code/user_code_service.py @@ -12,7 +12,6 @@ from ...types.errors import SyftException from ...types.result import Err from ...types.result import as_result -from ...types.syft_metaclass import Empty from ...types.twin_object import TwinObject from ...types.uid import UID from ..action.action_object import ActionObject @@ -149,14 +148,7 @@ def update( context: AuthedServiceContext, code_update: UserCodeUpdate, ) -> SyftSuccess: - code = self.stash.get_by_uid(context.credentials, code_update.id).unwrap() - # FIX: Check if this works (keep commented): - # self.stash.update(context.credentials, code).unwrap() - - if code_update.l0_deny_reason is not Empty: # type: ignore[comparison-overlap] - code.l0_deny_reason = code_update.l0_deny_reason - - updated_code = self.stash.update(context.credentials, code).unwrap() + updated_code = self.stash.update(context.credentials, code_update).unwrap() return SyftSuccess(message="UserCode updated successfully", value=updated_code) @service_method( @@ -360,7 +352,7 @@ def is_execution_allowed( output_policy: OutputPolicy | None, ) -> IsExecutionAllowedEnum: status = code.get_status(context).unwrap() - if not status.approved: + if not status.get_approved(context): return IsExecutionAllowedEnum.NOT_APPROVED elif self.has_code_permission(code, context) is HasCodePermissionEnum.DENIED: # TODO: Check enum above @@ -510,8 +502,10 @@ def _call( # code is from low side (L0 setup) status = code.get_status(context).unwrap() - if not status.approved: - raise SyftException(public_message=status.get_status_message()) + if not status.get_approved(context): + raise SyftException( + public_message=status.get_status_message(context) + ) output_policy_is_valid = False try: @@ -647,7 +641,7 @@ def store_execution_output( is_admin = context.role == ServiceRole.ADMIN - if not code.is_status_approved(context) and not is_admin: + if not code.get_status(context).get_approved(context) and not is_admin: raise SyftException(public_message="This UserCode is not approved") return code.store_execution_output( diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index 1a492ea1d53..6c2c5d679f6 100644 --- a/packages/syft/src/syft/service/request/request.py +++ b/packages/syft/src/syft/service/request/request.py @@ -66,8 +66,10 @@ class RequestStatus(Enum): APPROVED = 2 @classmethod - def from_usercode_status(cls, status: UserCodeStatusCollection) -> "RequestStatus": - if status.approved: + def from_usercode_status( + cls, status: UserCodeStatusCollection, context: AuthedServiceContext + ) -> "RequestStatus": + if status.get_approved(context): return RequestStatus.APPROVED elif status.denied: return RequestStatus.REJECTED @@ -531,7 +533,7 @@ def get_status(self, context: AuthedServiceContext | None = None) -> RequestStat code_status = ( self.code.get_status(context) if context else self.code.status ) - return RequestStatus.from_usercode_status(code_status) + return RequestStatus.from_usercode_status(code_status, context) except Exception: # nosec # this breaks when coming from a user submitting a request # which tries to send an email to the admin and ends up here @@ -612,7 +614,7 @@ def deny(self, reason: str) -> SyftSuccess: "This request already has results published to the data scientist. " "They will still be able to access those results." ) - api.code.update(id=self.code_id, l0_deny_reason=reason) + api.code_status.update(id=self.code_id, l0_deny_reason=reason) return SyftSuccess(message=f"Request denied with reason: {reason}") return api.services.request.undo(uid=self.id, reason=reason) diff --git a/packages/syft/src/syft/store/linked_obj.py b/packages/syft/src/syft/store/linked_obj.py index 5d9c29c9d9d..448ef7aaa58 100644 --- a/packages/syft/src/syft/store/linked_obj.py +++ b/packages/syft/src/syft/store/linked_obj.py @@ -42,7 +42,12 @@ def __str__(self) -> str: @property def resolve(self) -> SyftObject: + return self._resolve() + + def _resolve(self, load_cached=False) -> SyftObject: api = None + if load_cached and self._resolve_cache is not None: + return self._resolve_cache try: # relative api = self.get_api() # raises @@ -53,15 +58,25 @@ def resolve(self) -> SyftObject: logger.error(">>> Failed to resolve object", type(api), e) raise e + def resolve_dynamic(self, context: ServerServiceContext | None, load_cached=False): + if context is not None: + return self.resolve_with_context(context, load_cached) + else: + return self._resolve(load_cached) + @as_result(SyftException) - def resolve_with_context(self, context: ServerServiceContext) -> Any: + def resolve_with_context( + self, context: ServerServiceContext, load_cached=False + ) -> Any: if context.server is None: raise ValueError(f"context {context}'s server is None") - return ( + res = ( context.server.get_service(self.service_type) .resolve_link(context=context, linked_obj=self) .unwrap() ) + self._resolve_cache = res + return res def update_with_context( self, context: ServerServiceContext | ChangeContext | Any, obj: Any diff --git a/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py b/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py index 258a698dcd5..809a50e50a0 100644 --- a/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py +++ b/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py @@ -414,7 +414,7 @@ def compute() -> int: assert low_client.requests[0].status == RequestStatus.REJECTED # Un-deny. NOTE: not supported by current UX, this is just used to re-deny on high side - low_client.api.code.update(id=request_low.code_id, l0_deny_reason=None) + low_client.api.code_status.update(id=request_low.code_id, l0_deny_reason=None) assert low_client.requests[0].status == RequestStatus.PENDING # Sync request to high side From d2588c957d90278837a6d8764cb20ff3ab1d1a75 Mon Sep 17 00:00:00 2001 From: Koen van der Veen Date: Thu, 12 Sep 2024 12:15:58 +0200 Subject: [PATCH 02/15] update docs --- .../syft/src/syft/service/code/user_code.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 725eb6ac0da..371b9792486 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -149,11 +149,11 @@ def non_empty_reason(self): @serializable() class UserCodeStatusCollection(SyncableSyftObject): """Currently this is a class that implements a mixed bag of two statusses - The first status is for a level 0 Request, which does not have an explicit - status_dict. However, when we call .denied, or .approved, we compute the - status on the fly. + The first status is for a level 0 Request, which only uses the status dict + for denied decision. If there is no denied decision, it computes the status + by checking the backend for whether it has readable outputs. The second use case is for a level 2 Request, in this case we store the status - dict on the object and use it as is. + dict on the object and use it as is for both denied and approved status """ __canonical_name__ = "UserCodeStatusCollection" __version__ = SYFT_OBJECT_VERSION_1 @@ -191,7 +191,7 @@ def denied(self) -> bool: def _compute_status_l0( self, context: AuthedServiceContext | None = None ) -> UserCodeStatus: - has_readable_outputs = self.has_readable_outputs(context) + has_readable_outputs = self._has_readable_outputs(context) if self.denied: if has_readable_outputs: @@ -204,8 +204,8 @@ def _compute_status_l0( return UserCodeStatus.APPROVED else: return UserCodeStatus.PENDING - - def has_readable_outputs(self, context: AuthedServiceContext | None = None): + + def _has_readable_outputs(self, context: AuthedServiceContext | None = None): if context is None: # Clientside api = self._get_api() @@ -225,7 +225,7 @@ def has_readable_outputs(self, context: AuthedServiceContext | None = None): @property def denial_reason_l0(self): - denial_reasons = [x.non_empty_reason for x in self.status_dict.values() + denial_reasons = [x.non_empty_reason for x in self.status_dict.values() if x.status == UserCodeStatus.DENIED and x.non_empty_reason is not None] return next(iter(denial_reasons), "") @@ -271,7 +271,6 @@ def _repr_html_(self) -> str: def __repr_syft_nested__(self) -> str: string = "" - status_dict = for server_identity, (status, reason) in self.status_dict.items(): string += f"{server_identity.server_name}: {status}, {reason}
" return string @@ -1518,7 +1517,7 @@ def create_code_status(context: TransformContext) -> TransformContext: return context # # Low side requests have a computed status - # if + # if # return context was_requested_on_lowside = context.server.server_side_type == ServerSideType.LOW_SIDE From cece25863e46e947ae39b1a737b5cebc10dbcd39 Mon Sep 17 00:00:00 2001 From: Koen van der Veen Date: Thu, 12 Sep 2024 15:00:57 +0200 Subject: [PATCH 03/15] rewrite reprs --- .../syft/src/syft/service/code/user_code.py | 134 +++++++++++------- .../syft/service/code/user_code_service.py | 4 +- packages/syft/src/syft/service/context.py | 5 + .../syft/src/syft/service/request/request.py | 6 +- .../src/syft/service/sync/sync_service.py | 4 +- .../syft/util/notebook_ui/components/sync.py | 8 +- 6 files changed, 101 insertions(+), 60 deletions(-) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 371b9792486..1fc92087877 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -132,9 +132,9 @@ class UserCodeStatus(Enum): def __hash__(self) -> int: return hash(self.value) + @serializable(canonical_name="UserCodeStatusDecision", version=1) class ApprovalDecision(SyftObject): - status: UserCodeStatus reason: str | None = None @@ -145,16 +145,16 @@ def non_empty_reason(self): return self.reason - @serializable() class UserCodeStatusCollection(SyncableSyftObject): """Currently this is a class that implements a mixed bag of two statusses The first status is for a level 0 Request, which only uses the status dict for denied decision. If there is no denied decision, it computes the status - by checking the backend for whether it has readable outputs. + by checking the backend for whether it has readable outputs. The second use case is for a level 2 Request, in this case we store the status dict on the object and use it as is for both denied and approved status """ + __canonical_name__ = "UserCodeStatusCollection" __version__ = SYFT_OBJECT_VERSION_1 @@ -169,7 +169,7 @@ class UserCodeStatusCollection(SyncableSyftObject): _was_requested_on_lowside: bool = False # ugly and buggy optimization, remove at some point - _has_output_read_permissions_cache: bool | None = None + _has_readable_outputs_cache: bool | None = None @property def approved(self) -> bool: @@ -177,20 +177,34 @@ def approved(self) -> bool: # of using the context self.get_approved(None) - def get_approved(self, context: AuthedServiceContext | None): + def get_approved(self, context: AuthedServiceContext | None) -> bool: + return self._compute_status(context) == UserCodeStatus.APPROVED + + def _compute_status( + self, context: AuthedServiceContext | None = None + ) -> UserCodeStatus: if self._was_requested_on_lowside: - return self._compute_status_l0(context) == UserCodeStatus.APPROVED + return self._compute_status_l0(context) else: - return all(x == UserCodeStatus.APPROVED for x, _ in self.status_dict.values()) + return self._compute_status_l2() @property def denied(self) -> bool: # for denied we use the status dict both for level 0 and level 2 - return any([approval_dec.status == UserCodeStatus.DENIED for approval_dec in self.status_dict.values()]) + return any( + [ + approval_dec.status == UserCodeStatus.DENIED + for approval_dec in self.status_dict.values() + ] + ) def _compute_status_l0( self, context: AuthedServiceContext | None = None ) -> UserCodeStatus: + # for l0, if denied in status dict, its denied + # if not, and it has readable outputs, its approved, + # else pending + has_readable_outputs = self._has_readable_outputs(context) if self.denied: @@ -204,19 +218,33 @@ def _compute_status_l0( return UserCodeStatus.APPROVED else: return UserCodeStatus.PENDING - + + def _compute_status_l2(self) -> UserCodeStatus: + any_denied = any( + x == UserCodeStatus.DENIED for x, _ in self.status_dict.values() + ) + all_approved = all( + x == UserCodeStatus.APPROVED for x, _ in self.status_dict.values() + ) + if any_denied: + return UserCodeStatus.DENIED + elif all_approved: + return UserCodeStatus.APPROVED + else: + return UserCodeStatus.PENDING + def _has_readable_outputs(self, context: AuthedServiceContext | None = None): if context is None: # Clientside api = self._get_api() - if self._has_output_read_permissions_cache is None: + if self._has_readable_outputs_cache is None: has_readable_outputs = api.output.has_output_read_permissions( self.user_code_link.object_uid, self.user_verify_key ) - self._has_output_read_permissions_cache = has_readable_outputs + self._has_readable_outputs_cache = has_readable_outputs return has_readable_outputs else: - return self._has_output_read_permissions_cache + return self._has_readable_outputs_cache else: # Serverside return context.server.services.output.has_output_read_permissions( @@ -224,9 +252,12 @@ def _has_readable_outputs(self, context: AuthedServiceContext | None = None): ) @property - def denial_reason_l0(self): - denial_reasons = [x.non_empty_reason for x in self.status_dict.values() - if x.status == UserCodeStatus.DENIED and x.non_empty_reason is not None] + def first_denial_reason(self) -> str: + denial_reasons = [ + x.non_empty_reason + for x in self.status_dict.values() + if x.status == UserCodeStatus.DENIED and x.non_empty_reason is not None + ] return next(iter(denial_reasons), "") def syft_get_diffs(self, ext_obj: Any) -> list[AttrDiff]: @@ -255,41 +286,39 @@ def _repr_html_(self) -> str:

User Code Status

""" - for server_identity, (status, reason) in self.status_dict.items(): + for server_identity, approval_decision in self.status_dict.items(): server_name_str = f"{server_identity.server_name}" uid_str = f"{server_identity.server_id}" - status_str = f"{status.value}" + status_str = f"{approval_decision.status.value}" string += f""" • UID: {uid_str}  Server name: {server_name_str}  Status: {status_str}; - Reason: {reason} + Reason: {approval_decision.reason}
""" string += "

" return string def __repr_syft_nested__(self) -> str: - string = "" - for server_identity, (status, reason) in self.status_dict.items(): - string += f"{server_identity.server_name}: {status}, {reason}
" - return string + # this currently assumes that there is only one status + status_str = self._compute_status().value + + if self.denied: + status_str = f"{status_str}: self.first_denial_reason" + return status_str - def get_status_message(self, context: AuthedServiceContext) -> str: + def get_status_message_l2(self, context: AuthedServiceContext) -> str: if self.get_approved(context): return f"{type(self)} approved" denial_string = "" string = "" - status_dict = self._refresh_cache(context) - - for server_identity, (status, reason) in status_dict.items(): - denial_string += f"Code status on server '{server_identity.server_name}' is '{status}'. Reason: {reason}" - if not reason.endswith("."): + for server_identity, approval_decision in self.status_dict.items(): + denial_string += f"Code status on server '{server_identity.server_name}' is '{approval_decision.status}'. Reason: {approval_decision.reason}" + if not approval_decision.reason.endswith("."): denial_string += "." - string += ( - f"Code status on server '{server_identity.server_name}' is '{status}'." - ) + string += f"Code status on server '{server_identity.server_name}' is '{approval_decision.status}'." if self.denied: return f"{type(self)} Your code cannot be run: {denial_string}" else: @@ -298,7 +327,7 @@ def get_status_message(self, context: AuthedServiceContext) -> str: @as_result(SyftException) def mutate( self, - value: tuple[UserCodeStatus, str], + value: ApprovalDecision, server_name: str, server_id: UID, verify_key: SyftVerifyKey, @@ -411,14 +440,14 @@ def __setattr__(self, key: str, value: Any) -> None: return super().__setattr__(key, value) def _coll_repr_(self) -> dict[str, Any]: - status = [status for status, _ in self.status.status_dict.values()][0].value - if status == UserCodeStatus.PENDING.value: + status = self.status._compute_status() + if status == UserCodeStatus.PENDING: badge_color = "badge-purple" - elif status == UserCodeStatus.APPROVED.value: + elif status == UserCodeStatus.APPROVED: badge_color = "badge-green" else: badge_color = "badge-red" - status_badge = {"value": status, "type": badge_color} + status_badge = {"value": status.value, "type": badge_color} return { "Input Policy": self.input_policy_type.__canonical_name__, "Output Policy": self.output_policy_type.__canonical_name__, @@ -450,7 +479,9 @@ def status(self) -> UserCodeStatusCollection: return self.get_status(None) @as_result(SyftException) - def get_status(self, context: AuthedServiceContext | None) -> UserCodeStatusCollection: + def get_status( + self, context: AuthedServiceContext | None + ) -> UserCodeStatusCollection: return self.status_link.resolve_dynamic(context, load_cached=True) @property @@ -485,13 +516,8 @@ def output_readers(self) -> list[SyftVerifyKey] | None: return None @property - def code_status(self) -> list: - status_list = [] - for server_view, (status, _) in self.status.status_dict.items(): - status_list.append( - f"Server: {server_view.server_name}, Status: {status.value}", - ) - return status_list + def code_status_str(self) -> str: + return f"Status: {self.status._compute_status().value}" @property def input_policy(self) -> InputPolicy | None: @@ -859,7 +885,7 @@ def _inner_repr(self, level: int = 0) -> str: id: UID = {self.id} service_func_name: str = {self.service_func_name} shareholders: list = {self.input_owners} - status: list = {self.code_status} + status: str = {self.code_status_str} {constants_str} {shared_with_line} inputs: dict = {inputs_str} @@ -914,7 +940,7 @@ def _ipython_display_(self, level: int = 0) -> None:

{tabs}id: UID = {self.id}

{tabs}service_func_name: str = {self.service_func_name}

{tabs}shareholders: list = {self.input_owners}

-

{tabs}status: list = {self.code_status}

+

{tabs}status: str = {self.code_status_str}

{tabs}{constants_str} {tabs}{shared_with_line}

{tabs}inputs: dict =

{self._inputs_json}

@@ -1517,10 +1543,12 @@ def create_code_status(context: TransformContext) -> TransformContext: return context # # Low side requests have a computed status - # if + # if # return context - was_requested_on_lowside = context.server.server_side_type == ServerSideType.LOW_SIDE + was_requested_on_lowside = ( + context.server.server_side_type == ServerSideType.LOW_SIDE + ) code_link = LinkedObject.from_uid( context.output["id"], @@ -1535,15 +1563,19 @@ def create_code_status(context: TransformContext) -> TransformContext: verify_key=context.server.signing_key.verify_key, ) status = UserCodeStatusCollection( - status_dict={server_identity: (UserCodeStatus.PENDING, "")}, + status_dict={ + server_identity: ApprovalDecision(status=UserCodeStatus.PENDING) + }, user_code_link=code_link, user_verify_key=context.credentials, - _was_requested_on_lowside=was_requested_on_lowside + _was_requested_on_lowside=was_requested_on_lowside, ) elif context.server.server_type == ServerType.ENCLAVE: input_keys = list(context.output["input_policy_init_kwargs"].keys()) - status_dict = {key: (UserCodeStatus.PENDING, "") for key in input_keys} + status_dict = { + key: ApprovalDecision(status=UserCodeStatus.PENDING) for key in input_keys + } status = UserCodeStatusCollection( status_dict=status_dict, user_code_link=code_link, diff --git a/packages/syft/src/syft/service/code/user_code_service.py b/packages/syft/src/syft/service/code/user_code_service.py index 1fc3ef8cc49..f506572ecd9 100644 --- a/packages/syft/src/syft/service/code/user_code_service.py +++ b/packages/syft/src/syft/service/code/user_code_service.py @@ -502,7 +502,9 @@ def _call( # code is from low side (L0 setup) status = code.get_status(context).unwrap() - if not status.get_approved(context): + if context.server_allows_execution_for_ds and not status.get_approved( + context + ): raise SyftException( public_message=status.get_status_message(context) ) diff --git a/packages/syft/src/syft/service/context.py b/packages/syft/src/syft/service/context.py index 6e4037719f6..4ce07c67982 100644 --- a/packages/syft/src/syft/service/context.py +++ b/packages/syft/src/syft/service/context.py @@ -59,6 +59,11 @@ def is_l0_lowside(self) -> bool: """Returns True if this is a low side of a Level 0 deployment""" return self.server.server_side_type == ServerSideType.LOW_SIDE + @property + def server_allows_execution_for_ds(self) -> bool: + """Returns True if this is a low side of a Level 0 deployment""" + return not self.is_l0_lowside + def as_root_context(self) -> Self: return AuthedServiceContext( credentials=self.server.verify_key, diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index 6c2c5d679f6..d6bc12a8ef5 100644 --- a/packages/syft/src/syft/service/request/request.py +++ b/packages/syft/src/syft/service/request/request.py @@ -42,6 +42,7 @@ from ..action.action_object import ActionObject from ..action.action_store import ActionObjectPermission from ..action.action_store import ActionPermission +from ..code.user_code import ApprovalDecision from ..code.user_code import UserCode from ..code.user_code import UserCodeStatus from ..code.user_code import UserCodeStatusCollection @@ -1427,8 +1428,11 @@ def mutate( undo: bool, ) -> UserCodeStatusCollection: reason: str = context.extra_kwargs.get("reason", "") + ApprovalDecision return status.mutate( - value=(UserCodeStatus.DENIED if undo else self.value, reason), + value=ApprovalDecision( + decision=UserCodeStatus.DENIED if undo else self.value, reason=reason + ), server_name=context.server.name, server_id=context.server.id, verify_key=context.server.signing_key.verify_key, diff --git a/packages/syft/src/syft/service/sync/sync_service.py b/packages/syft/src/syft/service/sync/sync_service.py index 00393543636..47a91c5f022 100644 --- a/packages/syft/src/syft/service/sync/sync_service.py +++ b/packages/syft/src/syft/service/sync/sync_service.py @@ -101,9 +101,9 @@ def transform_item( if isinstance(item, UserCodeStatusCollection): identity = ServerIdentity.from_server(context.server) res = {} - for key in item.status_dict.keys(): + for approval_decision in item.status_dict.values(): # todo, check if they are actually only two servers - res[identity] = item.status_dict[key] + res[identity] = approval_decision item.status_dict = res self.set_obj_ids(context, item) diff --git a/packages/syft/src/syft/util/notebook_ui/components/sync.py b/packages/syft/src/syft/util/notebook_ui/components/sync.py index 94e54c60aed..a0fd095e81d 100644 --- a/packages/syft/src/syft/util/notebook_ui/components/sync.py +++ b/packages/syft/src/syft/util/notebook_ui/components/sync.py @@ -97,12 +97,10 @@ def get_status_str(self) -> str: return f"Status: {self.object.status.value}" elif isinstance(self.object, Request): code = self.object.code - statusses = list(code.status.status_dict.values()) - if len(statusses) != 1: + approval_decisions = list(code.status.status_dict.values()) + if len(approval_decisions) != 1: raise ValueError("Request code should have exactly one status") - status_tuple = statusses[0] - status, _ = status_tuple - return status.value + return approval_decisions[0].status.value return "" # type: ignore def get_updated_by(self) -> str: From 3e3c459ecf4173987ad788373ef729d51ca792aa Mon Sep 17 00:00:00 2001 From: Koen van der Veen Date: Thu, 12 Sep 2024 20:54:02 +0200 Subject: [PATCH 04/15] fix --- packages/syft/src/syft/__init__.py | 2 +- .../src/syft/protocol/protocol_version.json | 25 +++++ .../syft/src/syft/service/code/user_code.py | 100 +++++++++++++++++- 3 files changed, 123 insertions(+), 4 deletions(-) diff --git a/packages/syft/src/syft/__init__.py b/packages/syft/src/syft/__init__.py index 95cc582e83d..62655f0aed1 100644 --- a/packages/syft/src/syft/__init__.py +++ b/packages/syft/src/syft/__init__.py @@ -53,7 +53,7 @@ from .service.api.api import api_endpoint from .service.api.api import api_endpoint_method from .service.api.api import create_new_api_endpoint as TwinAPIEndpoint -from .service.code.user_code import UsercodeStatus +from .service.code.user_code import UserCodeStatus from .service.code.user_code import syft_function from .service.code.user_code import syft_function_single_use from .service.data_subject import DataSubjectCreate as DataSubject diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index 5f9f6a8fab1..0514e1b98d7 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -1,5 +1,30 @@ { "1": { "release_name": "0.9.1.json" + }, + "dev": { + "object_versions": { + "ApprovalDecision": { + "1": { + "version": 1, + "hash": "ecce7c6e01af68b0c0a73605f0c2226917f0784ecce69e9f64ce004b243252d4", + "action": "add" + } + }, + "UserCodeStatusCollection": { + "2": { + "version": 2, + "hash": "aacbdcc19141d96914ab10b6c3f9f4684fb3f71d405254df70602655539044c7", + "action": "add" + } + }, + "UserCode": { + "2": { + "version": 2, + "hash": "c127aa856b208f06f50131cd114a910b5b147e252efc9c962f0fe424a1edd264", + "action": "add" + } + } + } } } diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 1fc92087877..f4cb28ce14e 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -133,11 +133,14 @@ def __hash__(self) -> int: return hash(self.value) -@serializable(canonical_name="UserCodeStatusDecision", version=1) +@serializable() class ApprovalDecision(SyftObject): status: UserCodeStatus reason: str | None = None + __canonical_name__ = "ApprovalDecision" + __version__ = 1 + @property def non_empty_reason(self): if self.reason == "": @@ -146,7 +149,7 @@ def non_empty_reason(self): @serializable() -class UserCodeStatusCollection(SyncableSyftObject): +class UserCodeStatusCollectionV1(SyncableSyftObject): """Currently this is a class that implements a mixed bag of two statusses The first status is for a level 0 Request, which only uses the status dict for denied decision. If there is no denied decision, it computes the status @@ -160,6 +163,26 @@ class UserCodeStatusCollection(SyncableSyftObject): __repr_attrs__ = ["approved", "status_dict"] + # this is empty in the case of l0 + status_dict: dict[ServerIdentity, tuple[UserCodeStatus, str]] = {} + + user_code_link: LinkedObject + +@serializable() +class UserCodeStatusCollection(SyncableSyftObject): + """Currently this is a class that implements a mixed bag of two statusses + The first status is for a level 0 Request, which only uses the status dict + for denied decision. If there is no denied decision, it computes the status + by checking the backend for whether it has readable outputs. + The second use case is for a level 2 Request, in this case we store the status + dict on the object and use it as is for both denied and approved status + """ + + __canonical_name__ = "UserCodeStatusCollection" + __version__ = SYFT_OBJECT_VERSION_2 + + __repr_attrs__ = ["approved", "status_dict"] + # this is empty in the case of l0 status_dict: dict[ServerIdentity, ApprovalDecision] = {} @@ -350,7 +373,7 @@ def get_sync_dependencies(self, context: AuthedServiceContext) -> list[UID]: @serializable() -class UserCode(SyncableSyftObject): +class UserCodeV1(SyncableSyftObject): # version __canonical_name__ = "UserCode" __version__ = SYFT_OBJECT_VERSION_1 @@ -377,6 +400,77 @@ class UserCode(SyncableSyftObject): # tracks if the code calls datasite.something, variable is set during parsing uses_datasite: bool = False + nested_codes: dict[str, tuple[LinkedObject, dict]] | None = {} + worker_pool_name: str | None = None + origin_server_side_type: ServerSideType + l0_deny_reason: str | None = None + _has_output_read_permissions_cache: bool | None = None + + __table_coll_widths__ = [ + "min-content", + "auto", + "auto", + "auto", + "auto", + "auto", + "auto", + "auto", + ] + + __attr_searchable__: ClassVar[list[str]] = [ + "user_verify_key", + "service_func_name", + "code_hash", + ] + __attr_unique__: ClassVar[list[str]] = [] + __repr_attrs__: ClassVar[list[str]] = [ + "service_func_name", + "input_owners", + "code_status", + "worker_pool_name", + "l0_deny_reason", + "raw_code", + ] + + __exclude_sync_diff_attrs__: ClassVar[list[str]] = [ + "server_uid", + "code_status", + "input_policy_type", + "input_policy_init_kwargs", + "input_policy_state", + "output_policy_type", + "output_policy_init_kwargs", + "output_policy_state", + ] + +@serializable() +class UserCode(SyncableSyftObject): + # version + __canonical_name__ = "UserCode" + __version__ = SYFT_OBJECT_VERSION_2 + + id: UID + server_uid: UID | None = None + user_verify_key: SyftVerifyKey + raw_code: str + input_policy_type: type[InputPolicy] | UserPolicy + input_policy_init_kwargs: dict[Any, Any] | None = None + input_policy_state: bytes = b"" + output_policy_type: type[OutputPolicy] | UserPolicy + output_policy_init_kwargs: dict[Any, Any] | None = None + output_policy_state: bytes = b"" + parsed_code: str + service_func_name: str + unique_func_name: str + user_unique_func_name: str + code_hash: str + signature: inspect.Signature + status_link: LinkedObject | None = None + input_kwargs: list[str] + submit_time: DateTime | None = None + # tracks if the code calls datasite.something, variable is set during parsing + uses_datasite: bool = False + nested_codes: dict[str, tuple[LinkedObject, dict]] | None = {} worker_pool_name: str | None = None origin_server_side_type: ServerSideType From a4b64b5b9137900951e396ad0f4d506ae95f1738 Mon Sep 17 00:00:00 2001 From: Koen van der Veen Date: Tue, 17 Sep 2024 17:11:15 +0200 Subject: [PATCH 05/15] fix tests --- .../f6c08d0ae735435582a74331d6b9984e.json | 30 ++++++++++ .../src/syft/service/code/status_service.py | 14 ++++- .../syft/src/syft/service/code/user_code.py | 56 +++++++++++-------- .../syft/service/code/user_code_service.py | 14 +++-- .../syft/src/syft/service/request/request.py | 24 +++++--- .../syft/src/syft/service/sync/diff_state.py | 14 ++--- .../src/syft/service/sync/resolve_widget.py | 17 ++++-- packages/syft/src/syft/store/linked_obj.py | 12 ++-- .../service/sync/sync_resolve_single_test.py | 16 +++++- 9 files changed, 139 insertions(+), 58 deletions(-) create mode 100644 packages/syft/src/syft/protocol/f6c08d0ae735435582a74331d6b9984e.json diff --git a/packages/syft/src/syft/protocol/f6c08d0ae735435582a74331d6b9984e.json b/packages/syft/src/syft/protocol/f6c08d0ae735435582a74331d6b9984e.json new file mode 100644 index 00000000000..0514e1b98d7 --- /dev/null +++ b/packages/syft/src/syft/protocol/f6c08d0ae735435582a74331d6b9984e.json @@ -0,0 +1,30 @@ +{ + "1": { + "release_name": "0.9.1.json" + }, + "dev": { + "object_versions": { + "ApprovalDecision": { + "1": { + "version": 1, + "hash": "ecce7c6e01af68b0c0a73605f0c2226917f0784ecce69e9f64ce004b243252d4", + "action": "add" + } + }, + "UserCodeStatusCollection": { + "2": { + "version": 2, + "hash": "aacbdcc19141d96914ab10b6c3f9f4684fb3f71d405254df70602655539044c7", + "action": "add" + } + }, + "UserCode": { + "2": { + "version": 2, + "hash": "c127aa856b208f06f50131cd114a910b5b147e252efc9c962f0fe424a1edd264", + "action": "add" + } + } + } + } +} diff --git a/packages/syft/src/syft/service/code/status_service.py b/packages/syft/src/syft/service/code/status_service.py index a30409915ca..f5984b030a1 100644 --- a/packages/syft/src/syft/service/code/status_service.py +++ b/packages/syft/src/syft/service/code/status_service.py @@ -3,6 +3,7 @@ # third party # relative +from ...client.api import ServerIdentity from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey from ...store.document_store import DocumentStore @@ -22,6 +23,7 @@ from ..service import service_method from ..user.user_roles import ADMIN_ROLE_LEVEL from ..user.user_roles import GUEST_ROLE_LEVEL +from .user_code import ApprovalDecision from .user_code import UserCodeStatusCollection @@ -52,6 +54,7 @@ class CodeStatusUpdate(PartialSyftObject): __version__ = SYFT_OBJECT_VERSION_1 id: UID + decision: ApprovalDecision @serializable(canonical_name="UserCodeStatusService", version=1) @@ -69,10 +72,11 @@ def create( context: AuthedServiceContext, status: UserCodeStatusCollection, ) -> UserCodeStatusCollection: - return self.stash.set( + res = self.stash.set( credentials=context.credentials, obj=status, ).unwrap() + return res @service_method( path="code_status.update", @@ -84,7 +88,13 @@ def create( def update( self, context: AuthedServiceContext, code_update: CodeStatusUpdate ) -> SyftSuccess: - res = self.status.update(context.credentials, code_update).unwrap() + existing_status = self.stash.get_by_uid( + context.credentials, uid=code_update.id + ).unwrap() + server_identity = ServerIdentity.from_server(context.server) + existing_status.status_dict[server_identity] = code_update.decision + + res = self.stash.update(context.credentials, existing_status).unwrap() return SyftSuccess(message="UserCode updated successfully", value=res) @service_method( diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index f4cb28ce14e..73a2b7a0c00 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -142,7 +142,8 @@ class ApprovalDecision(SyftObject): __version__ = 1 @property - def non_empty_reason(self): + def reason_or_none(self) -> str | None: + # TODO: move to class creation if self.reason == "": return None return self.reason @@ -168,6 +169,7 @@ class UserCodeStatusCollectionV1(SyncableSyftObject): user_code_link: LinkedObject + @serializable() class UserCodeStatusCollection(SyncableSyftObject): """Currently this is a class that implements a mixed bag of two statusses @@ -189,7 +191,7 @@ class UserCodeStatusCollection(SyncableSyftObject): user_code_link: LinkedObject user_verify_key: SyftVerifyKey - _was_requested_on_lowside: bool = False + was_requested_on_lowside: bool = False # ugly and buggy optimization, remove at some point _has_readable_outputs_cache: bool | None = None @@ -198,15 +200,15 @@ class UserCodeStatusCollection(SyncableSyftObject): def approved(self) -> bool: # only use this on the client side, in this case we can use self.get_api instead # of using the context - self.get_approved(None) + return self.get_is_approved(None) - def get_approved(self, context: AuthedServiceContext | None) -> bool: + def get_is_approved(self, context: AuthedServiceContext | None) -> bool: return self._compute_status(context) == UserCodeStatus.APPROVED def _compute_status( self, context: AuthedServiceContext | None = None ) -> UserCodeStatus: - if self._was_requested_on_lowside: + if self.was_requested_on_lowside: return self._compute_status_l0(context) else: return self._compute_status_l2() @@ -215,10 +217,8 @@ def _compute_status( def denied(self) -> bool: # for denied we use the status dict both for level 0 and level 2 return any( - [ - approval_dec.status == UserCodeStatus.DENIED - for approval_dec in self.status_dict.values() - ] + approval_dec.status == UserCodeStatus.DENIED + for approval_dec in self.status_dict.values() ) def _compute_status_l0( @@ -244,10 +244,12 @@ def _compute_status_l0( def _compute_status_l2(self) -> UserCodeStatus: any_denied = any( - x == UserCodeStatus.DENIED for x, _ in self.status_dict.values() + approval_dec.status == UserCodeStatus.DENIED + for approval_dec in self.status_dict.values() ) all_approved = all( - x == UserCodeStatus.APPROVED for x, _ in self.status_dict.values() + approval_dec.status == UserCodeStatus.APPROVED + for approval_dec in self.status_dict.values() ) if any_denied: return UserCodeStatus.DENIED @@ -256,7 +258,9 @@ def _compute_status_l2(self) -> UserCodeStatus: else: return UserCodeStatus.PENDING - def _has_readable_outputs(self, context: AuthedServiceContext | None = None): + def _has_readable_outputs( + self, context: AuthedServiceContext | None = None + ) -> bool: if context is None: # Clientside api = self._get_api() @@ -277,9 +281,9 @@ def _has_readable_outputs(self, context: AuthedServiceContext | None = None): @property def first_denial_reason(self) -> str: denial_reasons = [ - x.non_empty_reason + x.reason_or_none for x in self.status_dict.values() - if x.status == UserCodeStatus.DENIED and x.non_empty_reason is not None + if x.status == UserCodeStatus.DENIED and x.reason_or_none is not None ] return next(iter(denial_reasons), "") @@ -332,14 +336,17 @@ def __repr_syft_nested__(self) -> str: return status_str def get_status_message_l2(self, context: AuthedServiceContext) -> str: - if self.get_approved(context): + if self.get_is_approved(context): return f"{type(self)} approved" denial_string = "" string = "" for server_identity, approval_decision in self.status_dict.items(): - denial_string += f"Code status on server '{server_identity.server_name}' is '{approval_decision.status}'. Reason: {approval_decision.reason}" - if not approval_decision.reason.endswith("."): + denial_string += ( + f"Code status on server '{server_identity.server_name}' is '{approval_decision.status}'." + f" Reason: {approval_decision.reason}" + ) + if approval_decision.reason and not approval_decision.reason.endswith("."): # type: ignore denial_string += "." string += f"Code status on server '{server_identity.server_name}' is '{approval_decision.status}'." if self.denied: @@ -443,6 +450,7 @@ class UserCodeV1(SyncableSyftObject): "output_policy_state", ] + @serializable() class UserCode(SyncableSyftObject): # version @@ -465,7 +473,7 @@ class UserCode(SyncableSyftObject): user_unique_func_name: str code_hash: str signature: inspect.Signature - status_link: LinkedObject | None = None + status_link: LinkedObject input_kwargs: list[str] submit_time: DateTime | None = None # tracks if the code calls datasite.something, variable is set during parsing @@ -496,9 +504,9 @@ class UserCode(SyncableSyftObject): __repr_attrs__: ClassVar[list[str]] = [ "service_func_name", "input_owners", - "code_status", + "status", "worker_pool_name", - "l0_deny_reason", + # "l0_deny_reason", "raw_code", ] @@ -570,7 +578,7 @@ def user(self) -> UserView: @property def status(self) -> UserCodeStatusCollection: # only use this client side - return self.get_status(None) + return self.get_status(None).unwrap() @as_result(SyftException) def get_status( @@ -621,7 +629,7 @@ def input_policy(self) -> InputPolicy | None: def get_input_policy(self, context: AuthedServiceContext) -> InputPolicy | None: status = self.get_status(context).unwrap() - if status.get_approved(context) or self.input_policy_type.has_safe_serde: + if status.get_is_approved(context) or self.input_policy_type.has_safe_serde: return self._get_input_policy() return None @@ -683,7 +691,7 @@ def input_policy(self, value: Any) -> None: # type: ignore def get_output_policy(self, context: AuthedServiceContext) -> OutputPolicy | None: status = self.get_status(context).unwrap() - if status.get_approved(context) or self.output_policy_type.has_safe_serde: + if status.get_is_approved(context) or self.output_policy_type.has_safe_serde: return self._get_output_policy() return None @@ -1662,7 +1670,7 @@ def create_code_status(context: TransformContext) -> TransformContext: }, user_code_link=code_link, user_verify_key=context.credentials, - _was_requested_on_lowside=was_requested_on_lowside, + was_requested_on_lowside=was_requested_on_lowside, ) elif context.server.server_type == ServerType.ENCLAVE: diff --git a/packages/syft/src/syft/service/code/user_code_service.py b/packages/syft/src/syft/service/code/user_code_service.py index f506572ecd9..d15aa4cc8bf 100644 --- a/packages/syft/src/syft/service/code/user_code_service.py +++ b/packages/syft/src/syft/service/code/user_code_service.py @@ -352,7 +352,7 @@ def is_execution_allowed( output_policy: OutputPolicy | None, ) -> IsExecutionAllowedEnum: status = code.get_status(context).unwrap() - if not status.get_approved(context): + if not status.get_is_approved(context): return IsExecutionAllowedEnum.NOT_APPROVED elif self.has_code_permission(code, context) is HasCodePermissionEnum.DENIED: # TODO: Check enum above @@ -502,11 +502,12 @@ def _call( # code is from low side (L0 setup) status = code.get_status(context).unwrap() - if context.server_allows_execution_for_ds and not status.get_approved( - context + if ( + context.server_allows_execution_for_ds + and not status.get_is_approved(context) ): raise SyftException( - public_message=status.get_status_message(context) + public_message=status.get_status_message_l2(context) ) output_policy_is_valid = False @@ -643,7 +644,10 @@ def store_execution_output( is_admin = context.role == ServiceRole.ADMIN - if not code.get_status(context).get_approved(context) and not is_admin: + if ( + not code.get_status(context).unwrap().get_is_approved(context) + and not is_admin + ): raise SyftException(public_message="This UserCode is not approved") return code.store_execution_output( diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index d6bc12a8ef5..3dc580a7c4f 100644 --- a/packages/syft/src/syft/service/request/request.py +++ b/packages/syft/src/syft/service/request/request.py @@ -70,7 +70,7 @@ class RequestStatus(Enum): def from_usercode_status( cls, status: UserCodeStatusCollection, context: AuthedServiceContext ) -> "RequestStatus": - if status.get_approved(context): + if status.get_is_approved(context): return RequestStatus.APPROVED elif status.denied: return RequestStatus.REJECTED @@ -486,6 +486,15 @@ def code_id(self) -> UID: public_message="This type of request does not have code associated with it." ) + @property + def status_id(self) -> UID: + for change in self.changes: + if isinstance(change, UserCodeStatusChange): + return change.linked_obj.object_uid + raise SyftException( + public_message="This type of request does not have code associated with it." + ) + @property def codes(self) -> Any: for change in self.changes: @@ -615,7 +624,11 @@ def deny(self, reason: str) -> SyftSuccess: "This request already has results published to the data scientist. " "They will still be able to access those results." ) - api.code_status.update(id=self.code_id, l0_deny_reason=reason) + api.code_status.update( + id=self.code.status_link.object_uid, + decision=ApprovalDecision(status=UserCodeStatus.DENIED, reason=reason), + ) + return SyftSuccess(message=f"Request denied with reason: {reason}") return api.services.request.undo(uid=self.id, reason=reason) @@ -1081,9 +1094,7 @@ def accept_by_depositing_result(self, result: Any, force: bool = False) -> Any: pass def get_sync_dependencies(self, context: AuthedServiceContext) -> list[UID]: - dependencies = [] - dependencies.append(self.code_id) - return dependencies + return [self.code_id, self.status_id] @serializable() @@ -1428,10 +1439,9 @@ def mutate( undo: bool, ) -> UserCodeStatusCollection: reason: str = context.extra_kwargs.get("reason", "") - ApprovalDecision return status.mutate( value=ApprovalDecision( - decision=UserCodeStatus.DENIED if undo else self.value, reason=reason + status=UserCodeStatus.DENIED if undo else self.value, reason=reason ), server_name=context.server.name, server_id=context.server.id, diff --git a/packages/syft/src/syft/service/sync/diff_state.py b/packages/syft/src/syft/service/sync/diff_state.py index d9340895fa7..d61356ed36a 100644 --- a/packages/syft/src/syft/service/sync/diff_state.py +++ b/packages/syft/src/syft/service/sync/diff_state.py @@ -921,15 +921,15 @@ def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str: return "" # Turns off the _repr_markdown_ of SyftObject def _get_visual_hierarchy( - self, server: ObjectDiff, visited: set[UID] | None = None + self, node: ObjectDiff, visited: set[UID] | None = None ) -> dict[ObjectDiff, dict]: visited = visited if visited is not None else set() - visited.add(server.object_id) + visited.add(node.object_id) _, child_types_map = self.visual_hierarchy - child_types = child_types_map.get(server.obj_type, []) - dep_ids = self.dependencies.get(server.object_id, []) + self.dependents.get( - server.object_id, [] + child_types = child_types_map.get(node.obj_type, []) + dep_ids = self.dependencies.get(node.object_id, []) + self.dependents.get( + node.object_id, [] ) result = {} @@ -1444,10 +1444,10 @@ def _create_batches( root_ids.append(diff.object_id) # type: ignore # Dependents are the reverse edges of the dependency graph - obj_dependents = {} + obj_dependents: dict = {} for parent, children in obj_dependencies.items(): for child in children: - obj_dependents[child] = obj_dependencies.get(child, []) + [parent] + obj_dependents[child] = obj_dependents.get(child, []) + [parent] for root_uid in root_ids: batch = ObjectDiffBatch.from_dependencies( diff --git a/packages/syft/src/syft/service/sync/resolve_widget.py b/packages/syft/src/syft/service/sync/resolve_widget.py index d27c106b49f..9b41fc13c70 100644 --- a/packages/syft/src/syft/service/sync/resolve_widget.py +++ b/packages/syft/src/syft/service/sync/resolve_widget.py @@ -505,20 +505,25 @@ def batch_diff_widgets(self) -> list[CollapsableObjectDiffWidget]: return dependent_diff_widgets @property - def dependent_root_diff_widgets(self) -> list[CollapsableObjectDiffWidget]: + def dependency_root_diff_widgets(self) -> list[CollapsableObjectDiffWidget]: dependencies = self.obj_diff_batch.get_dependencies( include_roots=True, include_batch_root=False ) - other_roots = [ - d for d in dependencies if d.object_id in self.obj_diff_batch.global_roots - ] + + # we show these above the line + dependents = self.obj_diff_batch.get_dependents( + include_roots=False, include_batch_root=False + ) + dependent_ids = [x.object_id for x in dependents] + # we skip the ones we already show above the line in the widget + context_diffs = [d for d in dependencies if d.object_id not in dependent_ids] widgets = [ CollapsableObjectDiffWidget( diff, direction=self.obj_diff_batch.sync_direction, build_state=self.build_state, ) - for diff in other_roots + for diff in context_diffs ] return widgets @@ -559,7 +564,7 @@ def build(self) -> VBox: self.id2widget = {} batch_diff_widgets = self.batch_diff_widgets - dependent_batch_diff_widgets = self.dependent_root_diff_widgets + dependent_batch_diff_widgets = self.dependency_root_diff_widgets main_object_diff_widget = self.main_object_diff_widget self.id2widget[main_object_diff_widget.diff.object_id] = main_object_diff_widget diff --git a/packages/syft/src/syft/store/linked_obj.py b/packages/syft/src/syft/store/linked_obj.py index 448ef7aaa58..379c966d22c 100644 --- a/packages/syft/src/syft/store/linked_obj.py +++ b/packages/syft/src/syft/store/linked_obj.py @@ -44,7 +44,7 @@ def __str__(self) -> str: def resolve(self) -> SyftObject: return self._resolve() - def _resolve(self, load_cached=False) -> SyftObject: + def _resolve(self, load_cached: bool = False) -> SyftObject: api = None if load_cached and self._resolve_cache is not None: return self._resolve_cache @@ -58,16 +58,20 @@ def _resolve(self, load_cached=False) -> SyftObject: logger.error(">>> Failed to resolve object", type(api), e) raise e - def resolve_dynamic(self, context: ServerServiceContext | None, load_cached=False): + def resolve_dynamic( + self, context: ServerServiceContext | None, load_cached: bool = False + ) -> SyftObject: if context is not None: - return self.resolve_with_context(context, load_cached) + return self.resolve_with_context(context, load_cached).unwrap() else: return self._resolve(load_cached) @as_result(SyftException) def resolve_with_context( - self, context: ServerServiceContext, load_cached=False + self, context: ServerServiceContext, load_cached: bool = False ) -> Any: + if load_cached and self._resolve_cache is not None: + return self._resolve_cache if context.server is None: raise ValueError(f"context {context}'s server is None") res = ( diff --git a/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py b/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py index 809a50e50a0..998de0f3ad9 100644 --- a/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py +++ b/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py @@ -10,7 +10,10 @@ from syft.client.syncing import compare_clients from syft.client.syncing import resolve from syft.server.worker import Worker +from syft.service.code.user_code import ApprovalDecision +from syft.service.code.user_code import UserCodeStatus from syft.service.job.job_stash import Job +from syft.service.request.request import Request from syft.service.request.request import RequestStatus from syft.service.response import SyftSuccess from syft.service.sync.resolve_widget import ResolveWidget @@ -359,7 +362,6 @@ def compute() -> int: client_low_ds.code.compute(blocking=True) assert "waiting for approval" in exc.value.public_message - assert "PENDING" in exc.value.public_message assert low_client.requests[0].status == RequestStatus.PENDING @@ -381,7 +383,12 @@ def compute() -> int: diff_before, diff_after = compare_and_resolve( from_client=high_client, to_client=low_client, share_private_data=True ) - assert len(diff_before.batches) == 1 and diff_before.batches[0].root_type is Job + assert len(diff_before.batches) == 2 + root_types = [x.root_type for x in diff_before.batches] + assert Job in root_types + assert ( + Request in root_types + ) # we have not configured it to count UserCode as a root type assert low_client.requests[0].status == RequestStatus.APPROVED assert client_low_ds.code.compute().get() == 42 @@ -414,7 +421,10 @@ def compute() -> int: assert low_client.requests[0].status == RequestStatus.REJECTED # Un-deny. NOTE: not supported by current UX, this is just used to re-deny on high side - low_client.api.code_status.update(id=request_low.code_id, l0_deny_reason=None) + low_client.api.code_status.update( + id=request_low.status_id, + decision=ApprovalDecision(status=UserCodeStatus.PENDING), + ) assert low_client.requests[0].status == RequestStatus.PENDING # Sync request to high side From d5d704184e67f11e2a378ba3e48c1bcf174a32af Mon Sep 17 00:00:00 2001 From: Koen van der Veen Date: Wed, 18 Sep 2024 11:07:57 +0200 Subject: [PATCH 06/15] fix tests --- packages/syft/src/syft/service/code/user_code.py | 2 +- packages/syft/src/syft/service/request/request.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 3353868b954..25cbed48256 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -578,7 +578,7 @@ def status(self) -> UserCodeStatusCollection: def get_status( self, context: AuthedServiceContext | None ) -> UserCodeStatusCollection: - return self.status_link.resolve_dynamic(context, load_cached=True) + return self.status_link.resolve_dynamic(context, load_cached=False) @property def input_owners(self) -> list[str] | None: diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index 1bc33abde65..4382ce383c7 100644 --- a/packages/syft/src/syft/service/request/request.py +++ b/packages/syft/src/syft/service/request/request.py @@ -491,7 +491,7 @@ def code_id(self) -> UID: def status_id(self) -> UID: for change in self.changes: if isinstance(change, UserCodeStatusChange): - return change.linked_obj.object_uid + return change.linked_obj.object_uid # type: ignore raise SyftException( public_message="This type of request does not have code associated with it." ) From d4a6fb9ffa6796761d68ad5044a4c970ac609b8d Mon Sep 17 00:00:00 2001 From: Koen van der Veen Date: Wed, 18 Sep 2024 11:51:16 +0200 Subject: [PATCH 07/15] fix syncing test --- tests/integration/local/twin_api_sync_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/integration/local/twin_api_sync_test.py b/tests/integration/local/twin_api_sync_test.py index 9e8d28adaa7..f8c137e5967 100644 --- a/tests/integration/local/twin_api_sync_test.py +++ b/tests/integration/local/twin_api_sync_test.py @@ -136,8 +136,7 @@ def compute(query): endpoint_path="testapi.query", endpoint_timeout=expected_timeout_after ) widget = sy.sync(from_client=high_client, to_client=low_client) - result = widget[0].click_sync() - assert result, result + widget._sync_all() timeout_after = ( full_low_worker.python_server.services.api.stash.get_all( From 1a99aa8120e786120ab251eb736e34ebe9be5887 Mon Sep 17 00:00:00 2001 From: Koen van der Veen Date: Wed, 18 Sep 2024 12:11:22 +0200 Subject: [PATCH 08/15] fix syncing test --- .../bigquery/sync/01-setup-high-low-datasites.ipynb | 7 +------ .../bigquery/sync/02-configure-api-and-sync.ipynb | 7 +------ .../bigquery/sync/03-ds-submit-request.ipynb | 7 +------ .../bigquery/sync/04-do-review-requests.ipynb | 12 ++++-------- .../scenarios/bigquery/sync/05-ds-get-results.ipynb | 7 +------ .../src/syft/util/notebook_ui/components/sync.py | 6 +++++- 6 files changed, 13 insertions(+), 33 deletions(-) diff --git a/notebooks/scenarios/bigquery/sync/01-setup-high-low-datasites.ipynb b/notebooks/scenarios/bigquery/sync/01-setup-high-low-datasites.ipynb index 633a73c38e4..f7849f15dc8 100644 --- a/notebooks/scenarios/bigquery/sync/01-setup-high-low-datasites.ipynb +++ b/notebooks/scenarios/bigquery/sync/01-setup-high-low-datasites.ipynb @@ -312,11 +312,6 @@ } ], "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -327,7 +322,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.3" + "version": "3.12.5" } }, "nbformat": 4, diff --git a/notebooks/scenarios/bigquery/sync/02-configure-api-and-sync.ipynb b/notebooks/scenarios/bigquery/sync/02-configure-api-and-sync.ipynb index 094841ef58e..02629a750df 100644 --- a/notebooks/scenarios/bigquery/sync/02-configure-api-and-sync.ipynb +++ b/notebooks/scenarios/bigquery/sync/02-configure-api-and-sync.ipynb @@ -602,11 +602,6 @@ } ], "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -617,7 +612,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.3" + "version": "3.12.5" } }, "nbformat": 4, diff --git a/notebooks/scenarios/bigquery/sync/03-ds-submit-request.ipynb b/notebooks/scenarios/bigquery/sync/03-ds-submit-request.ipynb index a2759038134..9d2abea0af7 100644 --- a/notebooks/scenarios/bigquery/sync/03-ds-submit-request.ipynb +++ b/notebooks/scenarios/bigquery/sync/03-ds-submit-request.ipynb @@ -304,11 +304,6 @@ } ], "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -319,7 +314,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.3" + "version": "3.12.5" } }, "nbformat": 4, diff --git a/notebooks/scenarios/bigquery/sync/04-do-review-requests.ipynb b/notebooks/scenarios/bigquery/sync/04-do-review-requests.ipynb index 4eec3d6e7b1..a02acfb211a 100644 --- a/notebooks/scenarios/bigquery/sync/04-do-review-requests.ipynb +++ b/notebooks/scenarios/bigquery/sync/04-do-review-requests.ipynb @@ -209,8 +209,9 @@ "metadata": {}, "outputs": [], "source": [ - "assert len(diffs.batches) == 1\n", - "assert diffs.batches[0].root_diff.obj_type.__qualname__ == \"Job\"" + "batch_root_strs = [x.root_diff.obj_type.__qualname__ for x in diffs.batches]\n", + "assert len(diffs.batches) == 3\n", + "assert \"Job\" in batch_root_strs" ] }, { @@ -285,11 +286,6 @@ } ], "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -300,7 +296,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.3" + "version": "3.12.5" } }, "nbformat": 4, diff --git a/notebooks/scenarios/bigquery/sync/05-ds-get-results.ipynb b/notebooks/scenarios/bigquery/sync/05-ds-get-results.ipynb index 1e61e0d8587..9523bbb75fa 100644 --- a/notebooks/scenarios/bigquery/sync/05-ds-get-results.ipynb +++ b/notebooks/scenarios/bigquery/sync/05-ds-get-results.ipynb @@ -126,11 +126,6 @@ } ], "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -141,7 +136,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.3" + "version": "3.12.5" } }, "nbformat": 4, diff --git a/packages/syft/src/syft/util/notebook_ui/components/sync.py b/packages/syft/src/syft/util/notebook_ui/components/sync.py index a0fd095e81d..693d9549367 100644 --- a/packages/syft/src/syft/util/notebook_ui/components/sync.py +++ b/packages/syft/src/syft/util/notebook_ui/components/sync.py @@ -13,6 +13,7 @@ from ....service.user.user import UserView from ....types.datetime import DateTime from ....types.datetime import format_timedelta_human_readable +from ....types.errors import SyftException from ....types.syft_object import SYFT_OBJECT_VERSION_1 from ....types.syft_object import SyftObject from ..icons import Icon @@ -112,7 +113,10 @@ def get_updated_by(self) -> str: user_view: UserView | None = None if isinstance(self.object, UserCode): - user_view = self.object.user + try: + user_view = self.object.user + except SyftException: + pass # nosec if isinstance(user_view, UserView): return f"Created by {user_view.email}" From e5453aa9f0af2a3e3b30b828c9c53815f1b5c81f Mon Sep 17 00:00:00 2001 From: Koen van der Veen Date: Thu, 19 Sep 2024 11:19:53 +0200 Subject: [PATCH 09/15] migration tests --- .gitignore | 3 + .../0-prepare-migration-data.ipynb | 7 +- .../1-dump-database-to-file.ipynb | 7 +- .../2-migrate-from-file.ipynb | 72 ++++++++++--------- .../syft/src/syft/service/code/user_code.py | 33 +++++++++ .../service/migration/migration_service.py | 39 +++++++--- .../migration/object_migration_state.py | 29 ++++++++ packages/syft/src/syft/store/db/stash.py | 11 ++- packages/syft/src/syft/types/transforms.py | 2 + 9 files changed, 149 insertions(+), 54 deletions(-) diff --git a/.gitignore b/.gitignore index de762cac9c8..7285116f6ca 100644 --- a/.gitignore +++ b/.gitignore @@ -88,3 +88,6 @@ packages/grid/helm/examples/dev/migration.yaml notebooks/scenarios/bigquery/*.json + +notebooks/tutorials/version-upgrades/*.yaml +notebooks/tutorials/version-upgrades/*.blob diff --git a/notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb b/notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb index 03d4e29cfc5..9425c5de960 100644 --- a/notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb +++ b/notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb @@ -236,11 +236,6 @@ } ], "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -251,7 +246,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.12.5" } }, "nbformat": 4, diff --git a/notebooks/tutorials/version-upgrades/1-dump-database-to-file.ipynb b/notebooks/tutorials/version-upgrades/1-dump-database-to-file.ipynb index bc1bd06f036..e8141ac1da9 100644 --- a/notebooks/tutorials/version-upgrades/1-dump-database-to-file.ipynb +++ b/notebooks/tutorials/version-upgrades/1-dump-database-to-file.ipynb @@ -129,11 +129,6 @@ } ], "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -144,7 +139,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.12.5" } }, "nbformat": 4, diff --git a/notebooks/tutorials/version-upgrades/2-migrate-from-file.ipynb b/notebooks/tutorials/version-upgrades/2-migrate-from-file.ipynb index 930e2212911..906a9281ba4 100644 --- a/notebooks/tutorials/version-upgrades/2-migrate-from-file.ipynb +++ b/notebooks/tutorials/version-upgrades/2-migrate-from-file.ipynb @@ -36,7 +36,7 @@ " name=\"test_upgradability\",\n", " dev_mode=True,\n", " reset=True,\n", - " port=\"auto\",\n", + " # port=\"auto\",\n", ")\n", "\n", "client = server.login(email=\"info@openmined.org\", password=\"changethis\")" @@ -85,8 +85,8 @@ "metadata": {}, "outputs": [], "source": [ - "res = client.load_migration_data(blob_path)\n", - "assert isinstance(res, sy.SyftSuccess), res.message" + "# syft absolute\n", + "from syft.service.migration.object_migration_state import MigrationData" ] }, { @@ -96,15 +96,18 @@ "metadata": {}, "outputs": [], "source": [ - "res" + "migration_data = MigrationData.from_file(blob_path)" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "id": "8", "metadata": {}, + "outputs": [], "source": [ - "# Post migration tests" + "res = client.load_migration_data(blob_path)\n", + "assert isinstance(res, sy.SyftSuccess), res.message" ] }, { @@ -114,17 +117,15 @@ "metadata": {}, "outputs": [], "source": [ - "assert len(client.users.get_all()) == 2" + "res" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "id": "10", "metadata": {}, - "outputs": [], "source": [ - "client_ds = server.login(email=\"ds@openmined.org\", password=\"pw\")" + "# Post migration tests" ] }, { @@ -133,7 +134,9 @@ "id": "11", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "assert len(client.users.get_all()) == 2" + ] }, { "cell_type": "code", @@ -141,6 +144,16 @@ "id": "12", "metadata": {}, "outputs": [], + "source": [ + "client_ds = server.login(email=\"ds@openmined.org\", password=\"pw\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], "source": [ "# syft absolute\n", "from syft.client.api import APIRegistry" @@ -149,7 +162,7 @@ { "cell_type": "code", "execution_count": null, - "id": "13", + "id": "14", "metadata": {}, "outputs": [], "source": [ @@ -159,7 +172,7 @@ { "cell_type": "code", "execution_count": null, - "id": "14", + "id": "15", "metadata": {}, "outputs": [], "source": [ @@ -169,7 +182,7 @@ { "cell_type": "code", "execution_count": null, - "id": "15", + "id": "16", "metadata": {}, "outputs": [], "source": [ @@ -179,7 +192,7 @@ { "cell_type": "code", "execution_count": null, - "id": "16", + "id": "17", "metadata": {}, "outputs": [], "source": [ @@ -193,7 +206,7 @@ { "cell_type": "code", "execution_count": null, - "id": "17", + "id": "18", "metadata": {}, "outputs": [], "source": [ @@ -204,7 +217,7 @@ { "cell_type": "code", "execution_count": null, - "id": "18", + "id": "19", "metadata": {}, "outputs": [], "source": [ @@ -215,7 +228,7 @@ { "cell_type": "code", "execution_count": null, - "id": "19", + "id": "20", "metadata": {}, "outputs": [], "source": [ @@ -225,7 +238,7 @@ { "cell_type": "code", "execution_count": null, - "id": "20", + "id": "21", "metadata": {}, "outputs": [], "source": [ @@ -235,7 +248,7 @@ { "cell_type": "code", "execution_count": null, - "id": "21", + "id": "22", "metadata": {}, "outputs": [], "source": [ @@ -245,7 +258,7 @@ { "cell_type": "code", "execution_count": null, - "id": "22", + "id": "23", "metadata": {}, "outputs": [], "source": [ @@ -255,7 +268,7 @@ { "cell_type": "code", "execution_count": null, - "id": "23", + "id": "24", "metadata": {}, "outputs": [], "source": [ @@ -265,7 +278,7 @@ { "cell_type": "code", "execution_count": null, - "id": "24", + "id": "25", "metadata": {}, "outputs": [], "source": [ @@ -276,7 +289,7 @@ { "cell_type": "code", "execution_count": null, - "id": "25", + "id": "26", "metadata": {}, "outputs": [], "source": [ @@ -287,18 +300,13 @@ { "cell_type": "code", "execution_count": null, - "id": "26", + "id": "27", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -309,7 +317,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.12.5" } }, "nbformat": 4, diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 25cbed48256..b21b269b4f8 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -51,6 +51,7 @@ from ...types.dicttuple import DictTuple from ...types.errors import SyftException from ...types.result import as_result +from ...types.syft_migration import migrate from ...types.syft_object import PartialSyftObject from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import SYFT_OBJECT_VERSION_2 @@ -58,7 +59,9 @@ from ...types.syncable_object import SyncableSyftObject from ...types.transforms import TransformContext from ...types.transforms import add_server_uid_for_key +from ...types.transforms import drop from ...types.transforms import generate_id +from ...types.transforms import make_set_default from ...types.transforms import transform from ...types.uid import UID from ...util.decorators import deprecated @@ -373,6 +376,31 @@ def get_sync_dependencies(self, context: AuthedServiceContext) -> list[UID]: return [self.user_code_link.object_uid] +@migrate(UserCodeStatusCollectionV1, UserCodeStatusCollection) +def migrate_user_code_to_v2() -> list[Callable]: + def update_statusdict(context: TransformContext) -> TransformContext: + res = {} + for server_identity, (status, reason) in context.obj.status_dict.items(): + res[server_identity] = ApprovalDecision(status=status, reason=reason) + context.output["status_dict"] = res + return context + + def set_user_verify_key(context: TransformContext) -> TransformContext: + authed_context = context.to_server_context() + user_code = context.obj.user_code_link.resolve_with_context( + authed_context + ).unwrap() + context.output["user_verify_key"] = user_code.user_verify_key + return context + + return [ + make_set_default("was_requested_on_lowside", False), + make_set_default("_has_readable_outputs_cache", None), + update_statusdict, + set_user_verify_key, + ] + + @serializable() class UserCodeV1(SyncableSyftObject): # version @@ -2060,3 +2088,8 @@ def load_approved_policy_code( load_policy_code(user_code.input_policy_type) if isinstance(user_code.output_policy_type, UserPolicy): load_policy_code(user_code.output_policy_type) + + +@migrate(UserCodeV1, UserCode) +def migrate_user_code_to_v2() -> list[Callable]: + return [drop("l0_deny_reason"), drop("_has_output_read_permissions_cache")] diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index b1d461e4e80..78bcc30458e 100644 --- a/packages/syft/src/syft/service/migration/migration_service.py +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -24,6 +24,7 @@ from ..response import SyftSuccess from ..service import AbstractService from ..service import service_method +from ..user.user import User from ..user.user_roles import ADMIN_ROLE_LEVEL from ..worker.utils import DEFAULT_WORKER_POOL_NAME from .object_migration_state import MigrationData @@ -256,6 +257,7 @@ def _create_migrated_objects( context: AuthedServiceContext, migrated_objects: list[SyftObject], ignore_existing: bool = True, + skip_check_type: bool = False, ) -> SyftSuccess: for migrated_object in migrated_objects: stash = self._search_stash_for_klass( @@ -265,6 +267,7 @@ def _create_migrated_objects( result = stash.set( context.credentials, obj=migrated_object, + skip_check_type=skip_check_type, ) # Exception from the new Error Handling pattern, no need to change if result.is_err(): @@ -286,14 +289,20 @@ def _update_migrated_objects( self, context: AuthedServiceContext, migrated_objects: list[SyftObject] ) -> SyftSuccess: for migrated_object in migrated_objects: - stash = self._search_stash_for_klass( - context, type(migrated_object) - ).unwrap() + if ( + isinstance(migrated_object, User) + and migrated_object.verify_key == context.server.verify_key + ): + self.stash.update_root_user(context, migrated_object).unwrap() + else: + stash = self._search_stash_for_klass( + context, type(migrated_object) + ).unwrap() - stash.update( - context.credentials, - obj=migrated_object, - ).unwrap() + stash.update( + context.credentials, + obj=migrated_object, + ).unwrap() return SyftSuccess(message="Updated migration objects!") @@ -304,6 +313,12 @@ def _migrate_objects( migration_objects: dict[type[SyftObject], list[SyftObject]], ) -> list[SyftObject]: migrated_objects = [] + + # def get_sorting_key_migration(klass): + # canonical_name = klass.__canonical_name__ + # latest_version = SyftObjectRegistry.get_latest_version(canonical_name) + # return getattr(latest_version, "__migration_priority__", 0) + for klass, objects in migration_objects.items(): canonical_name = klass.__canonical_name__ latest_version = SyftObjectRegistry.get_latest_version(canonical_name) @@ -435,11 +450,19 @@ def apply_migration_data( "please use 'client.load_migration_data' instead." ) + pre_objects = [ + o for objects in migration_data.store_objects.values() for o in objects + ] + self._create_migrated_objects( + context, pre_objects, skip_check_type=True + ).unwrap() + print("UPDATING") + # migrate + apply store objects migrated_objects = self._migrate_objects( context, migration_data.store_objects ).unwrap() - self._create_migrated_objects(context, migrated_objects).unwrap() + self._update_migrated_objects(context, migrated_objects).unwrap() # migrate+apply action objects migrated_actionobjects = self._migrate_objects( diff --git a/packages/syft/src/syft/service/migration/object_migration_state.py b/packages/syft/src/syft/service/migration/object_migration_state.py index 22363d867f2..3920b70d3be 100644 --- a/packages/syft/src/syft/service/migration/object_migration_state.py +++ b/packages/syft/src/syft/service/migration/object_migration_state.py @@ -6,6 +6,7 @@ from typing import Any # third party +import sqlalchemy from typing_extensions import Self import yaml @@ -16,6 +17,7 @@ from ...server.credentials import SyftSigningKey from ...server.credentials import SyftVerifyKey from ...store.db.stash import ObjectStash +from ...store.db.stash import with_session from ...store.document_store import PartitionKey from ...store.document_store_errors import NotFoundException from ...types.blob_storage import BlobStorageEntry @@ -32,7 +34,9 @@ from ...types.transforms import make_set_default from ...types.uid import UID from ...util.util import prompt_warning_message +from ..context import AuthedServiceContext from ..response import SyftSuccess +from ..user.user import User from ..worker.utils import DEFAULT_WORKER_POOL_NAME from ..worker.worker_image import SyftWorkerImage from ..worker.worker_pool import SyftWorker @@ -78,6 +82,31 @@ def get_by_name( filters={"canonical_name": canonical_name}, ).unwrap() + @as_result(SyftException) + @with_session + def update_root_user( + self, + context: AuthedServiceContext, + root_user: User, + session: sqlalchemy.orm.session, + ) -> SyftSuccess: + user_stash = context.server.services.user.stash + existing_user = user_stash.get_by_verify_key( + context.credentials, root_user.verify_key + ).unwrap() + stmt = user_stash.table.update().where( + user_stash.table.c.id == existing_user.id + ) + stmt = stmt.values(id=root_user.id) + result = session.execute(stmt) + session.commit() + if result.rowcount == 0: + raise NotFoundException( + f"User: {root_user.id} not found or no permission to update." + ) + + return user_stash.update(context.credentials, obj=root_user).unwrap() + @serializable() class StoreMetadata(SyftBaseObject): diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index aec2a2ed9c5..1101b09d1f4 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -204,7 +204,13 @@ def is_unique(self, obj: StashT, session: Session = None) -> bool: return False elif len(results) == 1: result = results[0] - return result.id == obj.id + res = result.id == obj.id + if not res: + # third party + import ipdb + + ipdb.set_trace() + return res return True @with_session @@ -360,8 +366,9 @@ def set( add_storage_permission: bool = True, # TODO: check the default value ignore_duplicates: bool = False, session: Session = None, + skip_check_type: bool = False, ) -> StashT: - if not self.allow_any_type: + if not self.allow_any_type and not skip_check_type: self.check_type(obj, self.object_type).unwrap() uid = obj.id diff --git a/packages/syft/src/syft/types/transforms.py b/packages/syft/src/syft/types/transforms.py index 60e9722a029..7ff980e692c 100644 --- a/packages/syft/src/syft/types/transforms.py +++ b/packages/syft/src/syft/types/transforms.py @@ -30,6 +30,8 @@ class TransformContext(Context): @classmethod def from_context(cls, obj: Any, context: Context | None = None) -> Self: + if isinstance(context, TransformContext): + return context t_context = cls() t_context.obj = obj try: From 5769f67717a0001858c1d1a31e15d87d73e98fb7 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Thu, 19 Sep 2024 14:03:00 +0200 Subject: [PATCH 10/15] fix migration tests --- .../service/migration/migration_service.py | 17 +++++++---------- packages/syft/src/syft/store/db/stash.py | 18 +++++++++++------- .../src/syft/store/document_store_errors.py | 4 ++++ packages/syft/src/syft/types/transforms.py | 2 -- 4 files changed, 22 insertions(+), 19 deletions(-) diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index 78bcc30458e..aa5c73b5a92 100644 --- a/packages/syft/src/syft/service/migration/migration_service.py +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -9,6 +9,7 @@ from ...store.db.db import DBManager from ...store.db.stash import ObjectStash from ...store.document_store_errors import NotFoundException +from ...store.document_store_errors import UniqueConstraintException from ...types.blob_storage import BlobStorageEntry from ...types.errors import SyftException from ...types.result import as_result @@ -24,7 +25,6 @@ from ..response import SyftSuccess from ..service import AbstractService from ..service import service_method -from ..user.user import User from ..user.user_roles import ADMIN_ROLE_LEVEL from ..worker.utils import DEFAULT_WORKER_POOL_NAME from .object_migration_state import MigrationData @@ -289,20 +289,17 @@ def _update_migrated_objects( self, context: AuthedServiceContext, migrated_objects: list[SyftObject] ) -> SyftSuccess: for migrated_object in migrated_objects: - if ( - isinstance(migrated_object, User) - and migrated_object.verify_key == context.server.verify_key - ): - self.stash.update_root_user(context, migrated_object).unwrap() - else: - stash = self._search_stash_for_klass( - context, type(migrated_object) - ).unwrap() + stash = self._search_stash_for_klass( + context, type(migrated_object) + ).unwrap() + try: stash.update( context.credentials, obj=migrated_object, ).unwrap() + except UniqueConstraintException as e: + print(f"Failed to update {migrated_object}: {e}") return SyftSuccess(message="Updated migration objects!") diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index 1101b09d1f4..85860821866 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -42,6 +42,7 @@ from ...util.telemetry import instrument from ..document_store_errors import NotFoundException from ..document_store_errors import StashException +from ..document_store_errors import UniqueConstraintException from .db import DBManager from .query import Query from .schema import PostgresBase @@ -205,11 +206,6 @@ def is_unique(self, obj: StashT, session: Session = None) -> bool: elif len(results) == 1: result = results[0] res = result.id == obj.id - if not res: - # third party - import ipdb - - ipdb.set_trace() return res return True @@ -434,7 +430,13 @@ def apply_partial_update( self.object_type.model_validate(original_obj) return original_obj - @as_result(StashException, NotFoundException, AttributeError, ValidationError) + @as_result( + StashException, + NotFoundException, + AttributeError, + ValidationError, + UniqueConstraintException, + ) @with_session def update( self, @@ -461,7 +463,9 @@ def update( # TODO has_permission is not used if not self.is_unique(obj): - raise StashException(f"Some fields are not unique for {type(obj).__name__}") + raise UniqueConstraintException( + f"Some fields are not unique for {type(obj).__name__} and unique fields {self.unique_fields}" + ) stmt = self.table.update().where(self._get_field_filter("id", obj.id)) stmt = self._apply_permission_filter( diff --git a/packages/syft/src/syft/store/document_store_errors.py b/packages/syft/src/syft/store/document_store_errors.py index 69da6b73a8f..04fb6777897 100644 --- a/packages/syft/src/syft/store/document_store_errors.py +++ b/packages/syft/src/syft/store/document_store_errors.py @@ -14,6 +14,10 @@ class StashException(SyftException): public_message = "There was an error retrieving data. Contact your admin." +class UniqueConstraintException(StashException): + public_message = "Another item with the same unique constraint already exists." + + class ObjectCRUDPermissionException(SyftException): public_message = "You do not have permission to perform this action." diff --git a/packages/syft/src/syft/types/transforms.py b/packages/syft/src/syft/types/transforms.py index 7ff980e692c..60e9722a029 100644 --- a/packages/syft/src/syft/types/transforms.py +++ b/packages/syft/src/syft/types/transforms.py @@ -30,8 +30,6 @@ class TransformContext(Context): @classmethod def from_context(cls, obj: Any, context: Context | None = None) -> Self: - if isinstance(context, TransformContext): - return context t_context = cls() t_context.obj = obj try: From c111fd46b3efbe2a23516f819af278880112560d Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Thu, 19 Sep 2024 14:32:44 +0200 Subject: [PATCH 11/15] fix lint --- packages/syft/src/syft/service/code/user_code.py | 10 +++++++++- .../syft/src/syft/service/worker/worker_image_stash.py | 1 + .../syft/src/syft/service/worker/worker_pool_stash.py | 1 + packages/syft/src/syft/service/worker/worker_stash.py | 1 + 4 files changed, 12 insertions(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index b21b269b4f8..fea4927cb82 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -377,9 +377,13 @@ def get_sync_dependencies(self, context: AuthedServiceContext) -> list[UID]: @migrate(UserCodeStatusCollectionV1, UserCodeStatusCollection) -def migrate_user_code_to_v2() -> list[Callable]: +def migrate_user_code_status_to_v2() -> list[Callable]: def update_statusdict(context: TransformContext) -> TransformContext: res = {} + if not isinstance(context.obj, UserCodeStatusCollectionV1): + raise Exception("Invalid object type") + if context.output is None: + raise Exception("Output is None") for server_identity, (status, reason) in context.obj.status_dict.items(): res[server_identity] = ApprovalDecision(status=status, reason=reason) context.output["status_dict"] = res @@ -387,6 +391,10 @@ def update_statusdict(context: TransformContext) -> TransformContext: def set_user_verify_key(context: TransformContext) -> TransformContext: authed_context = context.to_server_context() + if not isinstance(context.obj, UserCodeStatusCollectionV1): + raise Exception("Invalid object type") + if context.output is None: + raise Exception("Output is None") user_code = context.obj.user_code_link.resolve_with_context( authed_context ).unwrap() diff --git a/packages/syft/src/syft/service/worker/worker_image_stash.py b/packages/syft/src/syft/service/worker/worker_image_stash.py index dc220905839..29755e9ef07 100644 --- a/packages/syft/src/syft/service/worker/worker_image_stash.py +++ b/packages/syft/src/syft/service/worker/worker_image_stash.py @@ -33,6 +33,7 @@ def set( add_storage_permission: bool = True, ignore_duplicates: bool = False, session: Session = None, + skip_check_type: bool = False, ) -> SyftWorkerImage: # By default syft images have all read permission add_permissions = [] if add_permissions is None else add_permissions diff --git a/packages/syft/src/syft/service/worker/worker_pool_stash.py b/packages/syft/src/syft/service/worker/worker_pool_stash.py index 81a4f4741d2..3ae0a2d9ec2 100644 --- a/packages/syft/src/syft/service/worker/worker_pool_stash.py +++ b/packages/syft/src/syft/service/worker/worker_pool_stash.py @@ -42,6 +42,7 @@ def set( add_storage_permission: bool = True, ignore_duplicates: bool = False, session: Session = None, + skip_check_type: bool = False, ) -> WorkerPool: # By default all worker pools have all read permission add_permissions = [] if add_permissions is None else add_permissions diff --git a/packages/syft/src/syft/service/worker/worker_stash.py b/packages/syft/src/syft/service/worker/worker_stash.py index 48a192ecd19..d64314a5d81 100644 --- a/packages/syft/src/syft/service/worker/worker_stash.py +++ b/packages/syft/src/syft/service/worker/worker_stash.py @@ -35,6 +35,7 @@ def set( add_storage_permission: bool = True, ignore_duplicates: bool = False, session: Session = None, + skip_check_type: bool = False, ) -> SyftWorker: # By default all worker pools have all read permission add_permissions = [] if add_permissions is None else add_permissions From 5b362bad9e9fab624868ec0600c865b799f9e179 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Thu, 19 Sep 2024 16:00:33 +0200 Subject: [PATCH 12/15] filter out errored items in migrartion --- .../service/migration/migration_service.py | 77 +++++++++---------- 1 file changed, 37 insertions(+), 40 deletions(-) diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index aa5c73b5a92..dce89d4ed6f 100644 --- a/packages/syft/src/syft/service/migration/migration_service.py +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -9,7 +9,6 @@ from ...store.db.db import DBManager from ...store.db.stash import ObjectStash from ...store.document_store_errors import NotFoundException -from ...store.document_store_errors import UniqueConstraintException from ...types.blob_storage import BlobStorageEntry from ...types.errors import SyftException from ...types.result import as_result @@ -255,34 +254,38 @@ def create_migrated_objects( def _create_migrated_objects( self, context: AuthedServiceContext, - migrated_objects: list[SyftObject], + migrated_objects: dict[type[SyftObject], list[SyftObject]], ignore_existing: bool = True, skip_check_type: bool = False, - ) -> SyftSuccess: - for migrated_object in migrated_objects: - stash = self._search_stash_for_klass( - context, type(migrated_object) - ).unwrap() + ) -> dict[type[SyftObject], list[SyftObject]]: + created_objects: dict[type[SyftObject], list[SyftObject]] = {} + for key, objects in migrated_objects.items(): + created_objects[key] = [] + for migrated_object in objects: + stash = self._search_stash_for_klass( + context, type(migrated_object) + ).unwrap() - result = stash.set( - context.credentials, - obj=migrated_object, - skip_check_type=skip_check_type, - ) - # Exception from the new Error Handling pattern, no need to change - if result.is_err(): - # TODO: subclass a DuplicationKeyError - if ignore_existing and ( - "Duplication Key Error" in result.err()._private_message # type: ignore - or "Duplication Key Error" in result.err().public_message # type: ignore - ): - print( - f"{type(migrated_object)} #{migrated_object.id} already exists" - ) - continue - else: - result.unwrap() # this will raise the exception inside the wrapper - return SyftSuccess(message="Created migrate objects!") + result = stash.set( + context.credentials, + obj=migrated_object, + skip_check_type=skip_check_type, + ) + # Exception from the new Error Handling pattern, no need to change + if result.is_err(): + # TODO: subclass a DuplicationKeyError + if ignore_existing and ( + "Duplication Key Error" in result.err()._private_message # type: ignore + or "Duplication Key Error" in result.err().public_message # type: ignore + ): + print( + f"{type(migrated_object)} #{migrated_object.id} already exists" + ) + continue + else: + result.unwrap() # this will raise the exception inside the wrapper + created_objects[key].append(result.unwrap()) + return created_objects @as_result(SyftException) def _update_migrated_objects( @@ -293,13 +296,10 @@ def _update_migrated_objects( context, type(migrated_object) ).unwrap() - try: - stash.update( - context.credentials, - obj=migrated_object, - ).unwrap() - except UniqueConstraintException as e: - print(f"Failed to update {migrated_object}: {e}") + stash.update( + context.credentials, + obj=migrated_object, + ).unwrap() return SyftSuccess(message="Updated migration objects!") @@ -447,17 +447,14 @@ def apply_migration_data( "please use 'client.load_migration_data' instead." ) - pre_objects = [ - o for objects in migration_data.store_objects.values() for o in objects - ] - self._create_migrated_objects( - context, pre_objects, skip_check_type=True + created_objects = self._create_migrated_objects( + context, migration_data.store_objects, skip_check_type=True ).unwrap() - print("UPDATING") # migrate + apply store objects migrated_objects = self._migrate_objects( - context, migration_data.store_objects + context, + created_objects, ).unwrap() self._update_migrated_objects(context, migrated_objects).unwrap() From 4e3f245e9db44bf89097500ceb3f36950b849b90 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Thu, 19 Sep 2024 16:07:52 +0200 Subject: [PATCH 13/15] remove unused code --- .../service/migration/migration_service.py | 8 ++--- .../migration/object_migration_state.py | 29 ------------------- 2 files changed, 2 insertions(+), 35 deletions(-) diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index dce89d4ed6f..eef1a113af7 100644 --- a/packages/syft/src/syft/service/migration/migration_service.py +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -246,9 +246,10 @@ def create_migrated_objects( migrated_objects: list[SyftObject], ignore_existing: bool = True, ) -> SyftSuccess: - return self._create_migrated_objects( + self._create_migrated_objects( context, migrated_objects, ignore_existing=ignore_existing ).unwrap() + return SyftSuccess(message="Created migration objects!") @as_result(SyftException) def _create_migrated_objects( @@ -311,11 +312,6 @@ def _migrate_objects( ) -> list[SyftObject]: migrated_objects = [] - # def get_sorting_key_migration(klass): - # canonical_name = klass.__canonical_name__ - # latest_version = SyftObjectRegistry.get_latest_version(canonical_name) - # return getattr(latest_version, "__migration_priority__", 0) - for klass, objects in migration_objects.items(): canonical_name = klass.__canonical_name__ latest_version = SyftObjectRegistry.get_latest_version(canonical_name) diff --git a/packages/syft/src/syft/service/migration/object_migration_state.py b/packages/syft/src/syft/service/migration/object_migration_state.py index 3920b70d3be..22363d867f2 100644 --- a/packages/syft/src/syft/service/migration/object_migration_state.py +++ b/packages/syft/src/syft/service/migration/object_migration_state.py @@ -6,7 +6,6 @@ from typing import Any # third party -import sqlalchemy from typing_extensions import Self import yaml @@ -17,7 +16,6 @@ from ...server.credentials import SyftSigningKey from ...server.credentials import SyftVerifyKey from ...store.db.stash import ObjectStash -from ...store.db.stash import with_session from ...store.document_store import PartitionKey from ...store.document_store_errors import NotFoundException from ...types.blob_storage import BlobStorageEntry @@ -34,9 +32,7 @@ from ...types.transforms import make_set_default from ...types.uid import UID from ...util.util import prompt_warning_message -from ..context import AuthedServiceContext from ..response import SyftSuccess -from ..user.user import User from ..worker.utils import DEFAULT_WORKER_POOL_NAME from ..worker.worker_image import SyftWorkerImage from ..worker.worker_pool import SyftWorker @@ -82,31 +78,6 @@ def get_by_name( filters={"canonical_name": canonical_name}, ).unwrap() - @as_result(SyftException) - @with_session - def update_root_user( - self, - context: AuthedServiceContext, - root_user: User, - session: sqlalchemy.orm.session, - ) -> SyftSuccess: - user_stash = context.server.services.user.stash - existing_user = user_stash.get_by_verify_key( - context.credentials, root_user.verify_key - ).unwrap() - stmt = user_stash.table.update().where( - user_stash.table.c.id == existing_user.id - ) - stmt = stmt.values(id=root_user.id) - result = session.execute(stmt) - session.commit() - if result.rowcount == 0: - raise NotFoundException( - f"User: {root_user.id} not found or no permission to update." - ) - - return user_stash.update(context.credentials, obj=root_user).unwrap() - @serializable() class StoreMetadata(SyftBaseObject): From 3ba39aa7e1323b696441a4d73be1f892ddea44ba Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Thu, 19 Sep 2024 17:20:56 +0200 Subject: [PATCH 14/15] revert some notebook changes --- .../sync/01-setup-high-low-datasites.ipynb | 7 +- .../sync/02-configure-api-and-sync.ipynb | 7 +- .../bigquery/sync/03-ds-submit-request.ipynb | 7 +- .../bigquery/sync/05-ds-get-results.ipynb | 7 +- .../0-prepare-migration-data.ipynb | 7 +- .../1-dump-database-to-file.ipynb | 7 +- .../2-migrate-from-file.ipynb | 72 +++++++++---------- 7 files changed, 68 insertions(+), 46 deletions(-) diff --git a/notebooks/scenarios/bigquery/sync/01-setup-high-low-datasites.ipynb b/notebooks/scenarios/bigquery/sync/01-setup-high-low-datasites.ipynb index fb31b955983..691c58d4b00 100644 --- a/notebooks/scenarios/bigquery/sync/01-setup-high-low-datasites.ipynb +++ b/notebooks/scenarios/bigquery/sync/01-setup-high-low-datasites.ipynb @@ -312,6 +312,11 @@ } ], "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -322,7 +327,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.5" + "version": "3.12.3" } }, "nbformat": 4, diff --git a/notebooks/scenarios/bigquery/sync/02-configure-api-and-sync.ipynb b/notebooks/scenarios/bigquery/sync/02-configure-api-and-sync.ipynb index a5e652bfd77..99274aba2a8 100644 --- a/notebooks/scenarios/bigquery/sync/02-configure-api-and-sync.ipynb +++ b/notebooks/scenarios/bigquery/sync/02-configure-api-and-sync.ipynb @@ -599,6 +599,11 @@ } ], "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -609,7 +614,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.5" + "version": "3.12.3" } }, "nbformat": 4, diff --git a/notebooks/scenarios/bigquery/sync/03-ds-submit-request.ipynb b/notebooks/scenarios/bigquery/sync/03-ds-submit-request.ipynb index 9d2abea0af7..a2759038134 100644 --- a/notebooks/scenarios/bigquery/sync/03-ds-submit-request.ipynb +++ b/notebooks/scenarios/bigquery/sync/03-ds-submit-request.ipynb @@ -304,6 +304,11 @@ } ], "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -314,7 +319,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.5" + "version": "3.12.3" } }, "nbformat": 4, diff --git a/notebooks/scenarios/bigquery/sync/05-ds-get-results.ipynb b/notebooks/scenarios/bigquery/sync/05-ds-get-results.ipynb index 9523bbb75fa..1e61e0d8587 100644 --- a/notebooks/scenarios/bigquery/sync/05-ds-get-results.ipynb +++ b/notebooks/scenarios/bigquery/sync/05-ds-get-results.ipynb @@ -126,6 +126,11 @@ } ], "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -136,7 +141,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.5" + "version": "3.12.3" } }, "nbformat": 4, diff --git a/notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb b/notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb index 9425c5de960..03d4e29cfc5 100644 --- a/notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb +++ b/notebooks/tutorials/version-upgrades/0-prepare-migration-data.ipynb @@ -236,6 +236,11 @@ } ], "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -246,7 +251,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.5" + "version": "3.10.13" } }, "nbformat": 4, diff --git a/notebooks/tutorials/version-upgrades/1-dump-database-to-file.ipynb b/notebooks/tutorials/version-upgrades/1-dump-database-to-file.ipynb index e8141ac1da9..bc1bd06f036 100644 --- a/notebooks/tutorials/version-upgrades/1-dump-database-to-file.ipynb +++ b/notebooks/tutorials/version-upgrades/1-dump-database-to-file.ipynb @@ -129,6 +129,11 @@ } ], "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -139,7 +144,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.5" + "version": "3.10.13" } }, "nbformat": 4, diff --git a/notebooks/tutorials/version-upgrades/2-migrate-from-file.ipynb b/notebooks/tutorials/version-upgrades/2-migrate-from-file.ipynb index 906a9281ba4..930e2212911 100644 --- a/notebooks/tutorials/version-upgrades/2-migrate-from-file.ipynb +++ b/notebooks/tutorials/version-upgrades/2-migrate-from-file.ipynb @@ -36,7 +36,7 @@ " name=\"test_upgradability\",\n", " dev_mode=True,\n", " reset=True,\n", - " # port=\"auto\",\n", + " port=\"auto\",\n", ")\n", "\n", "client = server.login(email=\"info@openmined.org\", password=\"changethis\")" @@ -85,8 +85,8 @@ "metadata": {}, "outputs": [], "source": [ - "# syft absolute\n", - "from syft.service.migration.object_migration_state import MigrationData" + "res = client.load_migration_data(blob_path)\n", + "assert isinstance(res, sy.SyftSuccess), res.message" ] }, { @@ -96,18 +96,15 @@ "metadata": {}, "outputs": [], "source": [ - "migration_data = MigrationData.from_file(blob_path)" + "res" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "id": "8", "metadata": {}, - "outputs": [], "source": [ - "res = client.load_migration_data(blob_path)\n", - "assert isinstance(res, sy.SyftSuccess), res.message" + "# Post migration tests" ] }, { @@ -117,41 +114,31 @@ "metadata": {}, "outputs": [], "source": [ - "res" - ] - }, - { - "cell_type": "markdown", - "id": "10", - "metadata": {}, - "source": [ - "# Post migration tests" + "assert len(client.users.get_all()) == 2" ] }, { "cell_type": "code", "execution_count": null, - "id": "11", + "id": "10", "metadata": {}, "outputs": [], "source": [ - "assert len(client.users.get_all()) == 2" + "client_ds = server.login(email=\"ds@openmined.org\", password=\"pw\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "12", + "id": "11", "metadata": {}, "outputs": [], - "source": [ - "client_ds = server.login(email=\"ds@openmined.org\", password=\"pw\")" - ] + "source": [] }, { "cell_type": "code", "execution_count": null, - "id": "13", + "id": "12", "metadata": {}, "outputs": [], "source": [ @@ -162,7 +149,7 @@ { "cell_type": "code", "execution_count": null, - "id": "14", + "id": "13", "metadata": {}, "outputs": [], "source": [ @@ -172,7 +159,7 @@ { "cell_type": "code", "execution_count": null, - "id": "15", + "id": "14", "metadata": {}, "outputs": [], "source": [ @@ -182,7 +169,7 @@ { "cell_type": "code", "execution_count": null, - "id": "16", + "id": "15", "metadata": {}, "outputs": [], "source": [ @@ -192,7 +179,7 @@ { "cell_type": "code", "execution_count": null, - "id": "17", + "id": "16", "metadata": {}, "outputs": [], "source": [ @@ -206,7 +193,7 @@ { "cell_type": "code", "execution_count": null, - "id": "18", + "id": "17", "metadata": {}, "outputs": [], "source": [ @@ -217,7 +204,7 @@ { "cell_type": "code", "execution_count": null, - "id": "19", + "id": "18", "metadata": {}, "outputs": [], "source": [ @@ -228,7 +215,7 @@ { "cell_type": "code", "execution_count": null, - "id": "20", + "id": "19", "metadata": {}, "outputs": [], "source": [ @@ -238,7 +225,7 @@ { "cell_type": "code", "execution_count": null, - "id": "21", + "id": "20", "metadata": {}, "outputs": [], "source": [ @@ -248,7 +235,7 @@ { "cell_type": "code", "execution_count": null, - "id": "22", + "id": "21", "metadata": {}, "outputs": [], "source": [ @@ -258,7 +245,7 @@ { "cell_type": "code", "execution_count": null, - "id": "23", + "id": "22", "metadata": {}, "outputs": [], "source": [ @@ -268,7 +255,7 @@ { "cell_type": "code", "execution_count": null, - "id": "24", + "id": "23", "metadata": {}, "outputs": [], "source": [ @@ -278,7 +265,7 @@ { "cell_type": "code", "execution_count": null, - "id": "25", + "id": "24", "metadata": {}, "outputs": [], "source": [ @@ -289,7 +276,7 @@ { "cell_type": "code", "execution_count": null, - "id": "26", + "id": "25", "metadata": {}, "outputs": [], "source": [ @@ -300,13 +287,18 @@ { "cell_type": "code", "execution_count": null, - "id": "27", + "id": "26", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -317,7 +309,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.5" + "version": "3.10.13" } }, "nbformat": 4, From a0d1acaba427dd618056af37977af3cb2b8cf8c1 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Thu, 19 Sep 2024 17:26:17 +0200 Subject: [PATCH 15/15] rm file --- .../f6c08d0ae735435582a74331d6b9984e.json | 30 ------------------- 1 file changed, 30 deletions(-) delete mode 100644 packages/syft/src/syft/protocol/f6c08d0ae735435582a74331d6b9984e.json diff --git a/packages/syft/src/syft/protocol/f6c08d0ae735435582a74331d6b9984e.json b/packages/syft/src/syft/protocol/f6c08d0ae735435582a74331d6b9984e.json deleted file mode 100644 index 0514e1b98d7..00000000000 --- a/packages/syft/src/syft/protocol/f6c08d0ae735435582a74331d6b9984e.json +++ /dev/null @@ -1,30 +0,0 @@ -{ - "1": { - "release_name": "0.9.1.json" - }, - "dev": { - "object_versions": { - "ApprovalDecision": { - "1": { - "version": 1, - "hash": "ecce7c6e01af68b0c0a73605f0c2226917f0784ecce69e9f64ce004b243252d4", - "action": "add" - } - }, - "UserCodeStatusCollection": { - "2": { - "version": 2, - "hash": "aacbdcc19141d96914ab10b6c3f9f4684fb3f71d405254df70602655539044c7", - "action": "add" - } - }, - "UserCode": { - "2": { - "version": 2, - "hash": "c127aa856b208f06f50131cd114a910b5b147e252efc9c962f0fe424a1edd264", - "action": "add" - } - } - } - } -}