From 1dec1aac75e2b9ca634a67a1a84344447d23f65f Mon Sep 17 00:00:00 2001 From: Koen van der Veen Date: Fri, 19 Jul 2024 15:17:00 +0200 Subject: [PATCH 1/5] add code reloader --- packages/syft/src/syft/serde/recursive.py | 18 +++++- .../src/syft/service/action/action_service.py | 3 +- .../syft/src/syft/service/policy/policy.py | 55 ++++++++++++++----- 3 files changed, 58 insertions(+), 18 deletions(-) diff --git a/packages/syft/src/syft/serde/recursive.py b/packages/syft/src/syft/serde/recursive.py index 4c438975245..0d8f740f622 100644 --- a/packages/syft/src/syft/serde/recursive.py +++ b/packages/syft/src/syft/serde/recursive.py @@ -385,11 +385,22 @@ def rs_proto2object(proto: _DynamicStructBuilder) -> Any: canonical_name = proto.canonicalName version = getattr(proto, "version", -1) + # if "RepeatedCallPolicy" in canonical_name: + # import ipdb + # ipdb.set_trace() + if not SyftObjectRegistry.has_serde_class(canonical_name, version): + # import ipdb + # ipdb.set_trace() + from ..server.server import CODE_RELOADER + + for load_user_code in CODE_RELOADER.values(): + load_user_code() # third party - raise Exception( - f"proto2obj: {canonical_name} version {version} not in SyftObjectRegistry" - ) + if not SyftObjectRegistry.has_serde_class(canonical_name, version): + raise Exception( + f"proto2obj: {canonical_name} version {version} not in SyftObjectRegistry" + ) # TODO: 🐉 sort this out, basically sometimes the syft.user classes are not in the # module name space in sub-processes or threads even though they are loaded on start @@ -434,6 +445,7 @@ def rs_proto2object(proto: _DynamicStructBuilder) -> Any: if hasattr(class_type, "serde_constructor"): return class_type.serde_constructor(kwargs) + if issubclass(class_type, Enum) and "value" in kwargs: obj = class_type.__new__(class_type, kwargs["value"]) elif issubclass(class_type, BaseModel): diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index ab82c80a2b8..104656f2102 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -509,7 +509,8 @@ def _user_code_execute( mock_obj=result_action_object_mock, ) except Exception as e: - return Err(f"_user_code_execute failed. {e}") + import traceback + return Err(f"_user_code_execute failed. {e}, {traceback.format_exc()}") return Ok(result_action_object) def set_result_to_store( diff --git a/packages/syft/src/syft/service/policy/policy.py b/packages/syft/src/syft/service/policy/policy.py index ba1ae048f95..3c141960046 100644 --- a/packages/syft/src/syft/service/policy/policy.py +++ b/packages/syft/src/syft/service/policy/policy.py @@ -1133,19 +1133,46 @@ def submit_policy_code_to_user_code() -> list[Callable]: ] -def add_class_to_user_module(klass: type, unique_name: str) -> type: - klass.__module__ = "syft.user" - klass.__name__ = unique_name - # syft absolute - import syft as sy +# def add_class_to_user_module(klass: type, unique_name: str) -> type: +# klass.__module__ = "syft.user" +# klass.__name__ = unique_name +# # syft absolute +# import syft as sy + +# if not hasattr(sy, "user"): +# user_module = types.ModuleType("user") +# sys.modules["syft"].user = user_module +# user_module = sy.user +# setattr(user_module, unique_name, klass) +# sys.modules["syft"].user = user_module +# return klass + +def register_policy_class(klass: type, unique_name: str) -> None: + nonrecursive=False + _serialize = None + _deserialize=None + attributes = [x for x in klass.model_fields.keys()] + exclude_attrs=[] + serde_overrides = {} + hash_exclude_attrs = [] + cls = klass + attribute_types = [] + version = 1 + + serde_attributes = ( + nonrecursive, + _serialize, + _deserialize, + attributes, + exclude_attrs, + serde_overrides, + hash_exclude_attrs, + cls, + attribute_types, + version, + ) - if not hasattr(sy, "user"): - user_module = types.ModuleType("user") - sys.modules["syft"].user = user_module - user_module = sy.user - setattr(user_module, unique_name, klass) - sys.modules["syft"].user = user_module - return klass + SyftObjectRegistry.register_cls(canonical_name=unique_name, version=version, serde_attributes=serde_attributes) def execute_policy_code(user_policy: UserPolicy) -> Any: @@ -1168,8 +1195,8 @@ def execute_policy_code(user_policy: UserPolicy) -> Any: except Exception: exec(user_policy.byte_code) # nosec policy_class = eval(user_policy.unique_name) # nosec - - policy_class = add_class_to_user_module(policy_class, user_policy.unique_name) + + register_policy_class(policy_class, user_policy.unique_name) sys.stdout = stdout_ sys.stderr = stderr_ From 4a42037f8173f2a5f85aba4b2fc5f2f2e615ca07 Mon Sep 17 00:00:00 2001 From: Koen van der Veen Date: Fri, 19 Jul 2024 15:18:03 +0200 Subject: [PATCH 2/5] cleanup --- packages/syft/src/syft/serde/recursive.py | 5 ----- packages/syft/src/syft/service/policy/policy.py | 14 -------------- 2 files changed, 19 deletions(-) diff --git a/packages/syft/src/syft/serde/recursive.py b/packages/syft/src/syft/serde/recursive.py index 0d8f740f622..8e56148b938 100644 --- a/packages/syft/src/syft/serde/recursive.py +++ b/packages/syft/src/syft/serde/recursive.py @@ -385,13 +385,8 @@ def rs_proto2object(proto: _DynamicStructBuilder) -> Any: canonical_name = proto.canonicalName version = getattr(proto, "version", -1) - # if "RepeatedCallPolicy" in canonical_name: - # import ipdb - # ipdb.set_trace() if not SyftObjectRegistry.has_serde_class(canonical_name, version): - # import ipdb - # ipdb.set_trace() from ..server.server import CODE_RELOADER for load_user_code in CODE_RELOADER.values(): diff --git a/packages/syft/src/syft/service/policy/policy.py b/packages/syft/src/syft/service/policy/policy.py index 3c141960046..560e01f27b6 100644 --- a/packages/syft/src/syft/service/policy/policy.py +++ b/packages/syft/src/syft/service/policy/policy.py @@ -1133,20 +1133,6 @@ def submit_policy_code_to_user_code() -> list[Callable]: ] -# def add_class_to_user_module(klass: type, unique_name: str) -> type: -# klass.__module__ = "syft.user" -# klass.__name__ = unique_name -# # syft absolute -# import syft as sy - -# if not hasattr(sy, "user"): -# user_module = types.ModuleType("user") -# sys.modules["syft"].user = user_module -# user_module = sy.user -# setattr(user_module, unique_name, klass) -# sys.modules["syft"].user = user_module -# return klass - def register_policy_class(klass: type, unique_name: str) -> None: nonrecursive=False _serialize = None From 897c89353b07e72ddc55f252fe811ed4d3cbc4b2 Mon Sep 17 00:00:00 2001 From: Koen van der Veen Date: Fri, 19 Jul 2024 15:41:58 +0200 Subject: [PATCH 3/5] lint --- packages/syft/src/syft/serde/recursive.py | 3 +-- packages/syft/src/syft/service/action/action_service.py | 2 ++ 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/serde/recursive.py b/packages/syft/src/syft/serde/recursive.py index 8e56148b938..33bf94c8d4f 100644 --- a/packages/syft/src/syft/serde/recursive.py +++ b/packages/syft/src/syft/serde/recursive.py @@ -385,8 +385,8 @@ def rs_proto2object(proto: _DynamicStructBuilder) -> Any: canonical_name = proto.canonicalName version = getattr(proto, "version", -1) - if not SyftObjectRegistry.has_serde_class(canonical_name, version): + # relative from ..server.server import CODE_RELOADER for load_user_code in CODE_RELOADER.values(): @@ -440,7 +440,6 @@ def rs_proto2object(proto: _DynamicStructBuilder) -> Any: if hasattr(class_type, "serde_constructor"): return class_type.serde_constructor(kwargs) - if issubclass(class_type, Enum) and "value" in kwargs: obj = class_type.__new__(class_type, kwargs["value"]) elif issubclass(class_type, BaseModel): diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index 104656f2102..a4cfd4e369a 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -509,7 +509,9 @@ def _user_code_execute( mock_obj=result_action_object_mock, ) except Exception as e: + # stdlib import traceback + return Err(f"_user_code_execute failed. {e}, {traceback.format_exc()}") return Ok(result_action_object) From 3624b694228f1ed785e1522ce844f4d93ad1ba6f Mon Sep 17 00:00:00 2001 From: Koen van der Veen Date: Fri, 19 Jul 2024 15:43:05 +0200 Subject: [PATCH 4/5] undo debug --- packages/syft/src/syft/service/action/action_service.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index a4cfd4e369a..ab82c80a2b8 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -509,10 +509,7 @@ def _user_code_execute( mock_obj=result_action_object_mock, ) except Exception as e: - # stdlib - import traceback - - return Err(f"_user_code_execute failed. {e}, {traceback.format_exc()}") + return Err(f"_user_code_execute failed. {e}") return Ok(result_action_object) def set_result_to_store( From 251f8f6f37298bd7c3b7922b7a47f04a59611c6c Mon Sep 17 00:00:00 2001 From: Koen van der Veen Date: Fri, 19 Jul 2024 16:06:56 +0200 Subject: [PATCH 5/5] lint --- .../syft/src/syft/service/policy/policy.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/packages/syft/src/syft/service/policy/policy.py b/packages/syft/src/syft/service/policy/policy.py index 560e01f27b6..4bf96a58c5a 100644 --- a/packages/syft/src/syft/service/policy/policy.py +++ b/packages/syft/src/syft/service/policy/policy.py @@ -12,7 +12,6 @@ from inspect import Signature from io import StringIO import sys -import types from typing import Any from typing import ClassVar @@ -1134,15 +1133,15 @@ def submit_policy_code_to_user_code() -> list[Callable]: def register_policy_class(klass: type, unique_name: str) -> None: - nonrecursive=False + nonrecursive = False _serialize = None - _deserialize=None - attributes = [x for x in klass.model_fields.keys()] - exclude_attrs=[] - serde_overrides = {} - hash_exclude_attrs = [] + _deserialize = None + attributes = list(klass.model_fields.keys()) + exclude_attrs: list = [] + serde_overrides: dict = {} + hash_exclude_attrs: list = [] cls = klass - attribute_types = [] + attribute_types: list = [] version = 1 serde_attributes = ( @@ -1158,7 +1157,9 @@ def register_policy_class(klass: type, unique_name: str) -> None: version, ) - SyftObjectRegistry.register_cls(canonical_name=unique_name, version=version, serde_attributes=serde_attributes) + SyftObjectRegistry.register_cls( + canonical_name=unique_name, version=version, serde_attributes=serde_attributes + ) def execute_policy_code(user_policy: UserPolicy) -> Any: @@ -1181,7 +1182,7 @@ def execute_policy_code(user_policy: UserPolicy) -> Any: except Exception: exec(user_policy.byte_code) # nosec policy_class = eval(user_policy.unique_name) # nosec - + register_policy_class(policy_class, user_policy.unique_name) sys.stdout = stdout_