From 00a541aba6e93813721822902f6eea7c3c608fba Mon Sep 17 00:00:00 2001 From: Vadym Doroshenko <53558779+dvadym@users.noreply.github.com> Date: Thu, 6 Jul 2023 16:49:58 +0200 Subject: [PATCH] Compute public partition summary (#465) --- analysis/dataset_summary.py | 108 +++++++++++++++++++++++++ analysis/tests/dataset_summary_test.py | 42 ++++++++++ 2 files changed, 150 insertions(+) create mode 100644 analysis/dataset_summary.py create mode 100644 analysis/tests/dataset_summary_test.py diff --git a/analysis/dataset_summary.py b/analysis/dataset_summary.py new file mode 100644 index 00000000..63766fc3 --- /dev/null +++ b/analysis/dataset_summary.py @@ -0,0 +1,108 @@ +# Copyright 2023 OpenMined. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains dataset summary and the computation of the summary.""" + +import pipeline_dp +import dataclasses +from typing import Iterable + + +@dataclasses.dataclass +class PublicPartitionsSummary: + num_dataset_public_partitions: int + num_dataset_non_public_partitions: int + num_empty_public_partitions: int + + +_DATASET_PUBLIC = 1 +_EMPTY_PUBLIC = 2 +_DATASET_NONPUBLIC = 3 + + +def compute_public_partitions_summary(col, backend: pipeline_dp.PipelineBackend, + extractors: pipeline_dp.DataExtractors, + public_partitions): + """Computes Public Partitions Summary from dataset and public partitions. + + Args: + col: the raw dataset. The collection where all elements are of the same + type. + backend: pipeline backend which corresponds to the type of 'col'. + extractors: functions that extract needed pieces of information + from elements of 'col'. + public_partitions: a collection of partition keys that will be present + in the result. If not provided, partitions will be selected in a DP + manner. + + Returns: + 1 element collection, which contains a PublicPartitionsSummary object. + """ + dataset_partitions = backend.map(col, extractors.partition_extractor, + "Extract partitions") + # (partition) + + dataset_partitions = backend.distinct(dataset_partitions, "Distinct") + # (partition) + + dataset_partitions = backend.map(dataset_partitions, lambda x: (x, True), + "Keyed by partition") + # (partition, is_from_dataset=True) + + public_partitions = backend.map(public_partitions, lambda x: (x, False), + "Keyed by partition") + # (partition, is_from_dataset = False) + + partitions = backend.flatten([dataset_partitions, public_partitions], + "flatten") + # (partition, is_from_dataset: bool) + + col = backend.group_by_key(partitions, "Group by Key") + + # (partition, Iterable) + + def process_fn(_, a: Iterable[bool]) -> int: + # a contains up to 2 booleans. + # True means that the partition is dataset. + # False means that the partition is in public partitions. + a = list(a) + if len(a) == 2: + return _DATASET_PUBLIC + if a[0]: + return _DATASET_NONPUBLIC + return _EMPTY_PUBLIC + + col = backend.map_tuple(col, process_fn, "Get Partition Type") + # (partition_type:int) + + col = backend.count_per_element(col, "Count partition types") + # (partition_type:int, count_partition_type:int) + + col = backend.to_list(col, "To list") + + # 1 element with list of tuples (partition_type, count_partition_type) + + def to_summary(partition_types_count: list) -> PublicPartitionsSummary: + num_dataset_public = num_dataset_non_public = num_empty_public = 0 + for type, count in partition_types_count: + if type == _DATASET_PUBLIC: + num_dataset_public = count + elif type == _DATASET_NONPUBLIC: + num_dataset_non_public = count + else: + num_empty_public = count + + return PublicPartitionsSummary(num_dataset_public, + num_dataset_non_public, num_empty_public) + + return backend.map(col, to_summary, "ToSummary") diff --git a/analysis/tests/dataset_summary_test.py b/analysis/tests/dataset_summary_test.py new file mode 100644 index 00000000..3dd28d44 --- /dev/null +++ b/analysis/tests/dataset_summary_test.py @@ -0,0 +1,42 @@ +# Copyright 2023 OpenMined. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for the dataset summary.""" + +from absl.testing import absltest +from absl.testing import parameterized + +from analysis import dataset_summary +import pipeline_dp + + +class PublicDatasetSummaryTest(parameterized.TestCase): + + def test_compute_public_partitions_summary(self): + dataset = list(range(100)) + public_partitions = list(range(60, 121)) + extractors = pipeline_dp.DataExtractors( + partition_extractor=lambda x: x, privacy_id_extractor=lambda _: 0) + + summary = dataset_summary.compute_public_partitions_summary( + dataset, pipeline_dp.LocalBackend(), extractors, public_partitions) + + summary = list(summary)[0] + + self.assertEqual(summary.num_dataset_public_partitions, 40) + self.assertEqual(summary.num_dataset_non_public_partitions, 60) + self.assertEqual(summary.num_empty_public_partitions, 21) + + +if __name__ == '__main__': + absltest.main()