Skip to content

Commit

Permalink
Add multi-objective tuning support
Browse files Browse the repository at this point in the history
  • Loading branch information
YamLyubov committed Oct 31, 2023
1 parent 5161abd commit a9fd28c
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 11 deletions.
32 changes: 22 additions & 10 deletions golem/core/tuning/iopt_tuner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
from dataclasses import dataclass, field
from datetime import timedelta
from random import choice
Expand All @@ -16,6 +17,7 @@
from golem.core.optimisers.objective import ObjectiveEvaluate
from golem.core.tuning.search_space import SearchSpace, get_node_operation_parameter_label, convert_parameters
from golem.core.tuning.tuner_interface import BaseTuner, DomainGraphForTune
from golem.utilities.data_structures import ensure_wrapped_in_sequence


@dataclass
Expand Down Expand Up @@ -49,12 +51,13 @@ def from_parameters_dicts(float_parameters_dict: Optional[Dict[str, List]] = Non
class GolemProblem(Problem, Generic[DomainGraphForTune]):
def __init__(self, graph: DomainGraphForTune,
objective_evaluate: ObjectiveEvaluate,
problem_parameters: IOptProblemParameters):
problem_parameters: IOptProblemParameters,
objectives_number: int = 1):
super().__init__()
self.objective_evaluate = objective_evaluate
self.graph = graph

self.number_of_objectives = 1
self.number_of_objectives = objectives_number
self.number_of_constraints = 0

self.discrete_variable_names = problem_parameters.discrete_parameters_names
Expand Down Expand Up @@ -146,16 +149,18 @@ def tune(self, graph: DomainGraphForTune, show_progress: bool = True) -> DomainG
no_parameters_to_optimize = (not problem_parameters.discrete_parameters_names and
not problem_parameters.float_parameters_names)
self.init_check(graph)
objectives_number = len(ensure_wrapped_in_sequence(self.init_metric))
is_multi_objective = objectives_number > 1

if no_parameters_to_optimize:
self._stop_tuning_with_message(f'Graph "{graph.graph_description}" has no parameters to optimize')
final_graph = graph
tuned_graphs = graph
else:
if initial_parameters:
initial_point = Point(**initial_parameters)
self.solver_parameters.start_point = initial_point

problem = GolemProblem(graph, self.objective_evaluate, problem_parameters)
problem = GolemProblem(graph, self.objective_evaluate, problem_parameters, objectives_number)
solver = Solver(problem, parameters=self.solver_parameters)

if show_progress:
Expand All @@ -164,14 +169,21 @@ def tune(self, graph: DomainGraphForTune, show_progress: bool = True) -> DomainG

solver.solve()
solution = solver.get_results()
best_point = solution.best_trials[0].point
best_parameters = problem.get_parameters_dict_from_iopt_point(best_point)
final_graph = self.set_arg_graph(graph, best_parameters)

self.was_tuned = True
if not is_multi_objective:
best_point = solution.best_trials[0].point
best_parameters = problem.get_parameters_dict_from_iopt_point(best_point)
tuned_graphs = self.set_arg_graph(graph, best_parameters)
self.was_tuned = True
else:
tuned_graphs = []
for best_trial in solution.best_trials:
best_parameters = problem.get_parameters_dict_from_iopt_point(best_trial.point)
tuned_graph = self.set_arg_graph(deepcopy(graph), best_parameters)
tuned_graphs.append(tuned_graph)
self.was_tuned = True

# Validate if optimisation did well
graph = self.final_check(final_graph)
graph = self.final_check(tuned_graphs, is_multi_objective)
final_graph = self.adapter.restore(graph)
return final_graph

Expand Down
2 changes: 1 addition & 1 deletion test/unit/tuning/test_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_node_tuning(search_space, graph):
assert tuner.init_metric >= tuner.obtained_metric


@pytest.mark.parametrize('tuner_cls', [OptunaTuner])
@pytest.mark.parametrize('tuner_cls', [OptunaTuner, IOptTuner])
@pytest.mark.parametrize('init_graph, adapter, obj_eval',
[(mock_graph_with_params(), MockAdapter(),
MockObjectiveEvaluate(Objective({'sum_metric': ParamsSumMetric.get_value,
Expand Down

0 comments on commit a9fd28c

Please sign in to comment.