Skip to content

Commit

Permalink
dev: update visualization code
Browse files Browse the repository at this point in the history
  • Loading branch information
BillHuang2001 committed Apr 14, 2024
1 parent ba00e2b commit 78ad715
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 8 deletions.
6 changes: 3 additions & 3 deletions src/evox/monitors/eval_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def get_best_fitness(self):
def get_history(self):
return [self.opt_direction * fit for fit in self.fitness_history]

def plot(self, **kwargs):
def plot(self, problem_pf=None, **kwargs):
if not self.fitness_history:
warnings.warn("No fitness history recorded, return None")
return
Expand All @@ -182,9 +182,9 @@ def plot(self, **kwargs):
if n_objs == 1:
return plot.plot_obj_space_1d(self.fitness_history, **kwargs)
elif n_objs == 2:
return plot.plot_obj_space_2d(self.fitness_history, **kwargs)
return plot.plot_obj_space_2d(self.fitness_history, problem_pf, **kwargs)
elif n_objs == 3:
return plot.plot_obj_space_3d(self.fitness_history, **kwargs)
return plot.plot_obj_space_3d(self.fitness_history, problem_pf, **kwargs)
else:
warnings.warn("Not supported yet.")

Expand Down
7 changes: 3 additions & 4 deletions src/evox/monitors/pop_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def post_step(self, state):
fitness = jax.device_put(fitness, self.host)
self.fitness_history.append(fitness)

def plot(self, **kwargs):
def plot(self, problem_pf=None, **kwargs):
if not self.fitness_history:
warnings.warn("No fitness history recorded, return None")
return
Expand All @@ -79,9 +79,9 @@ def plot(self, **kwargs):
if n_objs == 1:
return plot.plot_obj_space_1d(self.fitness_history, **kwargs)
elif n_objs == 2:
return plot.plot_obj_space_2d(self.fitness_history, **kwargs)
return plot.plot_obj_space_2d(self.fitness_history, problem_pf, **kwargs)
elif n_objs == 3:
return plot.plot_obj_space_3d(self.fitness_history, **kwargs)
return plot.plot_obj_space_3d(self.fitness_history, problem_pf, **kwargs)
else:
warnings.warn("Not supported yet.")

Expand All @@ -90,4 +90,3 @@ def get_population_history(self):

def get_fitness_history(self):
return self.fitness_history

11 changes: 10 additions & 1 deletion src/evox/vis_tools/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,16 @@ def plot_obj_space_1d(fitness_history):
go.Scatter(x=generation, y=max_fitness, mode="lines", name="Max"),
go.Scatter(x=generation, y=median_fitness, mode="lines", name="Median"),
go.Scatter(x=generation, y=avg_fitness, mode="lines", name="Average"),
]
],
layout=go.Layout(
legend={
"x": 1,
"y": 1,
"xanchor": "auto",
"xanchor": "auto",
},
margin={"l": 0, "r": 0, "t": 0, "b": 0},
),
)

return fig
Expand Down

0 comments on commit 78ad715

Please sign in to comment.