Skip to content

Commit

Permalink
polishings
Browse files Browse the repository at this point in the history
  • Loading branch information
dvadym committed Jul 6, 2023
1 parent f8e1cb6 commit f7d98fb
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 23 deletions.
57 changes: 38 additions & 19 deletions analysis/dataset_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,30 +11,42 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""""""
"""Contains dataset summary and the computation of the summary."""

import pipeline_dp
import dataclasses
from typing import Iterable
from enum import Enum


@dataclasses.dataclass
class PublicPartitionsSummary:
num_dataset_public: int
num_dataset_non_public: int
num_empty_public: int
num_dataset_public_partitions: int
num_dataset_non_public_partitions: int
num_empty_public_partitions: int


class _PartitionType(Enum):
DATASET_PUBLIC = 1
EMPTY_PUBLIC = 2
DATASET_NONPUBLIC = 3
_DATASET_PUBLIC = 1
_EMPTY_PUBLIC = 2
_DATASET_NONPUBLIC = 3


def compute_public_partitions_summary(col, backend: pipeline_dp.PipelineBackend,
extractors: pipeline_dp.DataExtractors,
public_partitions):
"""Computes Public Partitions Summary from dataset and public partitions.
Args:
col: collection where all elements are of the same type.
backend: pipeline backend which corresponds to the type of 'col'.
extractors: functions that extract needed pieces of information
from elements of 'col'.
public_partitions: a collection of partition keys that will be present
in the result. If not provided, partitions will be selected in a DP
manner.
Returns:
1 element collection, which contains a PublicPartitionsSummary object.
"""
dataset_partitions = backend.map(col, extractors.partition_extractor,
"Extract partitions")
# (partition)
Expand All @@ -43,11 +55,11 @@ def compute_public_partitions_summary(col, backend: pipeline_dp.PipelineBackend,
# (partition)

dataset_partitions = backend.map(dataset_partitions, lambda x: (x, True),
"Keyd by partition")
"Keyed by partition")
# (partition, is_from_dataset=True)

public_partitions = backend.map(public_partitions, lambda x: (x, False),
"Keyd by partition")
"Keyed by partition")
# (partition, is_from_dataset = False)

partitions = backend.flatten([dataset_partitions, public_partitions],
Expand All @@ -58,26 +70,33 @@ def compute_public_partitions_summary(col, backend: pipeline_dp.PipelineBackend,

# (partition, Iterable)

def process_fn(_, a: Iterable[bool]) -> _PartitionType:
def process_fn(_, a: Iterable[bool]) -> int:
# a contains up to 2 booleans.
# True means that the partition is dataset.
# False means that the partition is in public partitions.
a = list(a)
if len(a) == 2:
return _PartitionType.DATASET_PUBLIC
return _DATASET_PUBLIC
if a[0]:
return _PartitionType.DATASET_NONPUBLIC
return _PartitionType.EMPTY_PUBLIC
return _DATASET_NONPUBLIC
return _EMPTY_PUBLIC

col = backend.map_tuple(col, process_fn, "Get Partition Type")
# (partition_type:int)

col = backend.count_per_element(col, "Count partition types")
# (partition_type:int, count_partition_type:int)

col = backend.to_list(col, "To list")

def to_summary(A: list) -> PublicPartitionsSummary:
# 1 element with list of tuples (partition_type, count_partition_type)

def to_summary(partition_types_count: list) -> PublicPartitionsSummary:
num_dataset_public = num_dataset_non_public = num_empty_public = 0
for type, count in A:
if type == _PartitionType.DATASET_PUBLIC:
for type, count in partition_types_count:
if type == _DATASET_PUBLIC:
num_dataset_public = count
elif type == _PartitionType.DATASET_NONPUBLIC:
elif type == _DATASET_NONPUBLIC:
num_dataset_non_public = count
else:
num_empty_public = count
Expand Down
8 changes: 4 additions & 4 deletions analysis/tests/dataset_summary_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TODO"""
"""Tests for the dataset summary."""

from absl.testing import absltest
from absl.testing import parameterized
Expand All @@ -33,9 +33,9 @@ def test_compute_public_partitions_summary(self):

summary = list(summary)[0]

self.assertEqual(summary.num_dataset_public, 40)
self.assertEqual(summary.num_dataset_non_public, 60)
self.assertEqual(summary.num_empty_public, 21)
self.assertEqual(summary.num_dataset_public_partitions, 40)
self.assertEqual(summary.num_dataset_non_public_partititions, 60)
self.assertEqual(summary.num_empty_public_partitions, 21)


if __name__ == '__main__':
Expand Down

0 comments on commit f7d98fb

Please sign in to comment.