Skip to content

Commit

Permalink
Custom combiner improvements (#463)
Browse files Browse the repository at this point in the history
  • Loading branch information
dvadym authored Jun 29, 2023
1 parent dada1fe commit 3bb6e30
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 14 deletions.
4 changes: 3 additions & 1 deletion pipeline_dp/combiners.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,9 @@ def create_compound_combiner_with_custom_combiners(
budget_accountant: budget_accounting.BudgetAccountant,
custom_combiners: Iterable[CustomCombiner]) -> CompoundCombiner:
for combiner in custom_combiners:
combiner.set_aggregate_params(aggregate_params)
copy_aggregate_params = copy.copy(aggregate_params)
copy_aggregate_params.custom_combiners = None
combiner.set_aggregate_params(copy_aggregate_params)
combiner.request_budget(budget_accountant)

return CompoundCombiner(custom_combiners, return_named_tuple=False)
17 changes: 5 additions & 12 deletions pipeline_dp/private_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,9 @@ def add_input_for_private_output(self, accumulator, input) -> Any:
"""

@abc.abstractmethod
def extract_private_output(self, accumulator, budget: Any):
def extract_private_output(
self, accumulator, budget: Any,
aggregate_params: pipeline_dp.AggregateParams) -> Any:
"""Computes private output.
'budget' is the object which returned from 'request_budget()'.
Expand All @@ -540,15 +542,6 @@ def request_budget(
live in the driver process.
"""

def set_aggregate_params(self,
aggregate_params: pipeline_dp.AggregateParams):
"""Sets aggregate parameters
The custom combiner can optionally use it for own DP parameter
computations.
"""
self._aggregate_params = aggregate_params


class _CombineFnCombiner(pipeline_dp.CustomCombiner):

Expand All @@ -571,7 +564,7 @@ def merge_accumulators(self, accumulator1, accumulator2):
def compute_metrics(self, accumulator):
"""Computes and returns the result of aggregation."""
return self._private_combine_fn.extract_private_output(
accumulator, self._budget)
accumulator, self._budget, self._aggregate_params)

def explain_computation(self) -> str:
# TODO: implement
Expand All @@ -583,7 +576,7 @@ def request_budget(self,
budget_accountant)

def set_aggregate_params(self, aggregate_params):
self._private_combine_fn.set_aggregate_params(aggregate_params)
self._aggregate_params = aggregate_params


@dataclasses.dataclass
Expand Down
2 changes: 1 addition & 1 deletion tests/private_beam_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,7 +913,7 @@ def add_input_for_private_output(self, accumulator, input):
def merge_accumulators(self, accumulators):
return sum(accumulators)

def extract_private_output(self, accumulator, budget):
def extract_private_output(self, accumulator, budget, params):
return accumulator

def request_budget(self, budget_accountant):
Expand Down

0 comments on commit 3bb6e30

Please sign in to comment.