diff --git a/.github/workflows/pr-tests-stack.yml b/.github/workflows/pr-tests-stack.yml index 21ca78a02e2..1161909779f 100644 --- a/.github/workflows/pr-tests-stack.yml +++ b/.github/workflows/pr-tests-stack.yml @@ -622,7 +622,6 @@ jobs: tox -e migration.test pr-tests-migrations-k8s: - if: false # skipping this job for now strategy: max-parallel: 99 matrix: diff --git a/packages/syft/src/syft/client/datasite_client.py b/packages/syft/src/syft/client/datasite_client.py index 7553344ad5a..0129b4a17ac 100644 --- a/packages/syft/src/syft/client/datasite_client.py +++ b/packages/syft/src/syft/client/datasite_client.py @@ -416,8 +416,14 @@ def get_migration_data(self, include_blobs: bool = True) -> MigrationData: return res - def load_migration_data(self, path: str | Path) -> SyftSuccess: - migration_data = MigrationData.from_file(path) + def load_migration_data( + self, path_or_data: str | Path | MigrationData + ) -> SyftSuccess: + if isinstance(path_or_data, MigrationData): + migration_data = path_or_data + else: + migration_data = MigrationData.from_file(path_or_data) + migration_data._set_obj_location_(self.id, self.verify_key) if self.id != migration_data.server_uid: diff --git a/packages/syft/src/syft/service/policy/policy.py b/packages/syft/src/syft/service/policy/policy.py index 1e33755418e..4662307c235 100644 --- a/packages/syft/src/syft/service/policy/policy.py +++ b/packages/syft/src/syft/service/policy/policy.py @@ -642,15 +642,16 @@ def allowed_ids_only( public_message=f"Invalid server type for code submission: {context.server.server_type}" ) - server_identity = ServerIdentity( - server_name=context.server.name, - server_id=context.server.id, - verify_key=context.server.signing_key.verify_key, - ) - allowed_inputs = allowed_inputs.get(server_identity, {}) + allowed_inputs_for_server = None + for identity, inputs in allowed_inputs.items(): + if identity.server_id == context.server.id: + allowed_inputs_for_server = inputs + break + if allowed_inputs_for_server is None: + allowed_inputs_for_server = {} filtered_kwargs = {} - for key in allowed_inputs.keys(): + for key in allowed_inputs_for_server.keys(): if key in kwargs: value = kwargs[key] uid = value @@ -658,7 +659,7 @@ def allowed_ids_only( if not isinstance(uid, UID): uid = getattr(value, "id", None) - if uid != allowed_inputs[key]: + if uid != allowed_inputs_for_server[key]: raise SyftException( public_message=f"Input with uid: {uid} for `{key}` not in allowed inputs: {allowed_inputs}" )