Skip to content

Commit

Permalink
Merge pull request #9066 from OpenMined/fix-code-reloader
Browse files Browse the repository at this point in the history
add code reloader
  • Loading branch information
shubham3121 authored Jul 25, 2024
2 parents cec00dc + 7461673 commit 85a4d86
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 17 deletions.
12 changes: 9 additions & 3 deletions packages/syft/src/syft/serde/recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,10 +386,16 @@ def rs_proto2object(proto: _DynamicStructBuilder) -> Any:
version = getattr(proto, "version", -1)

if not SyftObjectRegistry.has_serde_class(canonical_name, version):
# relative
from ..server.server import CODE_RELOADER

for load_user_code in CODE_RELOADER.values():
load_user_code()
# third party
raise Exception(
f"proto2obj: {canonical_name} version {version} not in SyftObjectRegistry"
)
if not SyftObjectRegistry.has_serde_class(canonical_name, version):
raise Exception(
f"proto2obj: {canonical_name} version {version} not in SyftObjectRegistry"
)

# TODO: 🐉 sort this out, basically sometimes the syft.user classes are not in the
# module name space in sub-processes or threads even though they are loaded on start
Expand Down
42 changes: 28 additions & 14 deletions packages/syft/src/syft/service/policy/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from inspect import Signature
from io import StringIO
import sys
import types
from typing import Any
from typing import ClassVar

Expand Down Expand Up @@ -1133,19 +1132,34 @@ def submit_policy_code_to_user_code() -> list[Callable]:
]


def add_class_to_user_module(klass: type, unique_name: str) -> type:
klass.__module__ = "syft.user"
klass.__name__ = unique_name
# syft absolute
import syft as sy
def register_policy_class(klass: type, unique_name: str) -> None:
nonrecursive = False
_serialize = None
_deserialize = None
attributes = list(klass.model_fields.keys())
exclude_attrs: list = []
serde_overrides: dict = {}
hash_exclude_attrs: list = []
cls = klass
attribute_types: list = []
version = 1

serde_attributes = (
nonrecursive,
_serialize,
_deserialize,
attributes,
exclude_attrs,
serde_overrides,
hash_exclude_attrs,
cls,
attribute_types,
version,
)

if not hasattr(sy, "user"):
user_module = types.ModuleType("user")
sys.modules["syft"].user = user_module
user_module = sy.user
setattr(user_module, unique_name, klass)
sys.modules["syft"].user = user_module
return klass
SyftObjectRegistry.register_cls(
canonical_name=unique_name, version=version, serde_attributes=serde_attributes
)


def execute_policy_code(user_policy: UserPolicy) -> Any:
Expand All @@ -1169,7 +1183,7 @@ def execute_policy_code(user_policy: UserPolicy) -> Any:
exec(user_policy.byte_code) # nosec
policy_class = eval(user_policy.unique_name) # nosec

policy_class = add_class_to_user_module(policy_class, user_policy.unique_name)
register_policy_class(policy_class, user_policy.unique_name)

sys.stdout = stdout_
sys.stderr = stderr_
Expand Down

0 comments on commit 85a4d86

Please sign in to comment.