diff --git a/examples-gallery/plot_betts2003.py b/examples-gallery/plot_betts2003.py index a8baa0f..11fcfea 100644 --- a/examples-gallery/plot_betts2003.py +++ b/examples-gallery/plot_betts2003.py @@ -119,6 +119,7 @@ def obj_grad(free): plt.tight_layout() +# %% fig_y2, axes_y2 = plt.subplots(3, 1) axes_y2[0].plot(time, y2_m, '.k', @@ -135,4 +136,10 @@ def obj_grad(free): plt.tight_layout() +# %% +prob.plot_constraint_violations(solution) + +# %% +prob.plot_trajectories(solution) + plt.show() diff --git a/opty/direct_collocation.py b/opty/direct_collocation.py index 62f4864..2a47ba9 100644 --- a/opty/direct_collocation.py +++ b/opty/direct_collocation.py @@ -391,7 +391,15 @@ def plot_trajectories(self, vector, axes=None): self.collocator.num_input_trajectories) traj_syms = (self.collocator.state_symbols + self.collocator.input_trajectories) - trajectories = np.vstack((state_traj, input_traj)) + + trajectories = state_traj + + if self.collocator.num_known_input_trajectories > 0: + known_traj = list(self.collocator.known_trajectory_map.values()) + trajectories = np.vstack((trajectories, known_traj)) + + if self.collocator.num_unknown_input_trajectories > 0: + trajectories = np.vstack((trajectories, input_traj)) if axes is None: fig, axes = plt.subplots(num_axes, 1, sharex=True,