Skip to content

Commit

Permalink
Merge pull request OpenMined#9311 from OpenMined/eelco/notebook-tests…
Browse files Browse the repository at this point in the history
…-timeout

fix exactmatch policy for server with different name but same ID
  • Loading branch information
eelcovdw authored Sep 23, 2024
2 parents 08b4f53 + e49039b commit f5c0e8c
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 11 deletions.
1 change: 0 additions & 1 deletion .github/workflows/pr-tests-stack.yml
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,6 @@ jobs:
tox -e migration.test
pr-tests-migrations-k8s:
if: false # skipping this job for now
strategy:
max-parallel: 99
matrix:
Expand Down
10 changes: 8 additions & 2 deletions packages/syft/src/syft/client/datasite_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,8 +416,14 @@ def get_migration_data(self, include_blobs: bool = True) -> MigrationData:

return res

def load_migration_data(self, path: str | Path) -> SyftSuccess:
migration_data = MigrationData.from_file(path)
def load_migration_data(
self, path_or_data: str | Path | MigrationData
) -> SyftSuccess:
if isinstance(path_or_data, MigrationData):
migration_data = path_or_data
else:
migration_data = MigrationData.from_file(path_or_data)

migration_data._set_obj_location_(self.id, self.verify_key)

if self.id != migration_data.server_uid:
Expand Down
17 changes: 9 additions & 8 deletions packages/syft/src/syft/service/policy/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,23 +642,24 @@ def allowed_ids_only(
public_message=f"Invalid server type for code submission: {context.server.server_type}"
)

server_identity = ServerIdentity(
server_name=context.server.name,
server_id=context.server.id,
verify_key=context.server.signing_key.verify_key,
)
allowed_inputs = allowed_inputs.get(server_identity, {})
allowed_inputs_for_server = None
for identity, inputs in allowed_inputs.items():
if identity.server_id == context.server.id:
allowed_inputs_for_server = inputs
break
if allowed_inputs_for_server is None:
allowed_inputs_for_server = {}

filtered_kwargs = {}
for key in allowed_inputs.keys():
for key in allowed_inputs_for_server.keys():
if key in kwargs:
value = kwargs[key]
uid = value

if not isinstance(uid, UID):
uid = getattr(value, "id", None)

if uid != allowed_inputs[key]:
if uid != allowed_inputs_for_server[key]:
raise SyftException(
public_message=f"Input with uid: {uid} for `{key}` not in allowed inputs: {allowed_inputs}"
)
Expand Down

0 comments on commit f5c0e8c

Please sign in to comment.