Skip to content

Commit

Permalink
Small refactoring in DPEngine (#477)
Browse files Browse the repository at this point in the history
  • Loading branch information
dvadym authored Jul 31, 2023
1 parent 9413b3c commit 0dee86d
Showing 1 changed file with 25 additions and 20 deletions.
45 changes: 25 additions & 20 deletions pipeline_dp/dp_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
# limitations under the License.
"""DP aggregations."""
import functools
from typing import Any, Optional, Tuple
from typing import Any, Callable, Optional, Tuple

import pipeline_dp
from pipeline_dp import budget_accounting
from pipeline_dp import combiners
from pipeline_dp import contribution_bounders
from pipeline_dp import partition_selection
from pipeline_dp import pipeline_functions
from pipeline_dp import report_generator
from pipeline_dp import sampling_utils
from pipeline_dp.dataset_histograms import computing_histograms
Expand Down Expand Up @@ -109,14 +110,16 @@ def _aggregate(self, col, params: pipeline_dp.AggregateParams,
else:
combiner = self._create_compound_combiner(params)

col = self._extract_columns(col, data_extractors)
# col : (privacy_id, partition_key, value)
if (public_partitions is not None and
not params.public_partitions_already_filtered):
col = self._drop_partitions(col, public_partitions, data_extractors)
col = self._drop_partitions(col,
public_partitions,
partition_extractor=lambda row: row[1])
self._add_report_stage(
f"Public partition selection: dropped non public partitions")
if not params.contribution_bounds_already_enforced:
col = self._extract_columns(col, data_extractors)
# col : (privacy_id, partition_key, value)
contribution_bounder = self._create_contribution_bounder(params)
col = contribution_bounder.bound_contributions(
col, params, self._backend, self._current_report_generator,
Expand All @@ -127,11 +130,8 @@ def _aggregate(self, col, params: pipeline_dp.AggregateParams,
"Drop privacy id")
# col : (partition_key, accumulator)
else:
# Extract the columns.
col = self._backend.map(
col, lambda row: (data_extractors.partition_extractor(row),
data_extractors.value_extractor(row)),
"Extract (partition_key, value))")
col = self._backend.map(col, lambda row: row[1:],
"Remove privacy_id")
# col : (partition_key, value)

col = self._backend.map_values(
Expand Down Expand Up @@ -273,15 +273,13 @@ def sample_unique_elements_fn(pid_and_pks):

return col

def _drop_partitions(self, col, partitions,
data_extractors: pipeline_dp.DataExtractors):
"""Drops partitions in `col` which are not in `public_partitions`."""
col = self._backend.map(
col, lambda row: (data_extractors.partition_extractor(row), row),
"Extract partition id")
def _drop_partitions(self, col, partitions, partition_extractor: Callable):
"""Drops partitions in `col` which are not in `partitions`."""
col = pipeline_functions.key_by(self._backend, col, partition_extractor,
"Key by partition")
col = self._backend.filter_by_key(col, partitions,
"Filtering out partitions")
return self._backend.map_tuple(col, lambda k, v: v, "Drop key")
return self._backend.values(col, "Drop key")

def _add_empty_public_partitions(self, col, public_partitions,
aggregator_fn):
Expand Down Expand Up @@ -381,10 +379,16 @@ def _create_contribution_bounder(
def _extract_columns(self, col,
data_extractors: pipeline_dp.DataExtractors):
"""Extract columns using data_extractors."""
if data_extractors.privacy_id_extractor is None:
# This is the case when contribution bounding is already enforced, so
# there is no need to extract privacy_id.
privacy_id_extractor = lambda row: None
else:
privacy_id_extractor = data_extractors.privacy_id_extractor
return self._backend.map(
col, lambda row: (data_extractors.privacy_id_extractor(row),
data_extractors.partition_extractor(row),
data_extractors.value_extractor(row)),
col, lambda row:
(privacy_id_extractor(row), data_extractors.partition_extractor(
row), data_extractors.value_extractor(row)),
"Extract (privacy_id, partition_key, value))")

def _check_aggregate_params(self,
Expand Down Expand Up @@ -460,7 +464,8 @@ def calculate_private_contribution_bounds(
col, params, data_extractors)

if not partitions_already_filtered:
col = self._drop_partitions(col, partitions, data_extractors)
col = self._drop_partitions(col, partitions,
data_extractors.partition_extractor)

histograms = computing_histograms.compute_dataset_histograms(
col, data_extractors, self._backend)
Expand Down

0 comments on commit 0dee86d

Please sign in to comment.