Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dvadym committed Jul 20, 2023
1 parent 08f09a8 commit ce2d2fd
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 7 deletions.
10 changes: 5 additions & 5 deletions analysis/parameter_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,16 +349,16 @@ def _check_tune_args(options: TuneOptions, is_public_partitions: bool):
# Empty metrics means tuning for select_partitions.
if is_public_partitions:
# Empty metrics means that partition selection tuning is performed.
raise ValueError("empty metrics means tuning of partition selection"
" but public partitions were provided")
raise ValueError("Empty metrics means tuning of partition selection"
" but public partitions were provided.")
elif len(metrics) > 1:
raise NotImplementedError(
f"Tuning supports only one metrics, but {metrics} given.")
raise ValueError(
f"Tuning supports only one metric, but {metrics} given.")
else: # len(metrics) == 1
if metrics[0] not in [
pipeline_dp.Metrics.COUNT, pipeline_dp.Metrics.PRIVACY_ID_COUNT
]:
raise NotImplementedError(
raise ValueError(
f"Tuning is supported only for Count and Privacy id count, but {metrics[0]} given."
)

Expand Down
34 changes: 32 additions & 2 deletions analysis/tests/parameter_tuning_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,8 +390,38 @@ def test_tune_privacy_id_count(self):
self.assertEqual(utility_reports[0].metric_errors[0].metric,
pipeline_dp.Metrics.PRIVACY_ID_COUNT)

def test_tune_params_validation(self):
pass
@parameterized.named_parameters(
dict(testcase_name="Select partition and public partition",
error_msg="Empty metrics means tuning of partition selection but"
" public partitions were provided",
metrics=[],
is_public_partitions=True),
dict(testcase_name="Multiple metrics",
error_msg="Tuning supports only one metric",
metrics=[
pipeline_dp.Metrics.COUNT, pipeline_dp.Metrics.PRIVACY_ID_COUNT
],
is_public_partitions=True),
dict(
testcase_name="Mean is not supported",
error_msg="Tuning is supported only for Count and Privacy id count",
metrics=[pipeline_dp.Metrics.MEAN],
is_public_partitions=False),
)
def test_tune_params_validation(self, error_msg,
metrics: list[pipeline_dp.Metric],
is_public_partitions: bool):
tune_options = _get_tune_options()
tune_options.aggregate_params.metrics = metrics
contribution_histograms = histograms.DatasetHistograms(
None, None, None, None, None)
data_extractors = pipeline_dp.DataExtractors(
privacy_id_extractor=lambda _: 0, partition_extractor=lambda _: 0)
public_partitions = [1] if is_public_partitions else None
with self.assertRaisesRegex(ValueError, error_msg):
parameter_tuning.tune(input, pipeline_dp.LocalBackend(),
contribution_histograms, tune_options,
data_extractors, public_partitions)


if __name__ == '__main__':
Expand Down

0 comments on commit ce2d2fd

Please sign in to comment.