Skip to content

Commit

Permalink
Merge pull request #201 from moorepants/clean-create-obj
Browse files Browse the repository at this point in the history
Made create_objective_function signature match rest of opty, made node_time_interval required.
  • Loading branch information
moorepants committed Aug 4, 2024
2 parents 745e2ac + c741175 commit 8dc15d6
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 29 deletions.
2 changes: 1 addition & 1 deletion examples-gallery/plot_drone.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@
specified_symbols,
tuple(),
num_nodes,
node_time_interval=interval_value)
interval_value)

# %%
# Specify the symbolic instance constraints.
Expand Down
2 changes: 1 addition & 1 deletion examples-gallery/plot_parallel_park.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@
specified_symbols,
tuple(),
num_nodes,
node_time_interval=interval_value)
interval_value)

# %%
# Specify the symbolic instance constraints, i.e. initial and end conditions.
Expand Down
2 changes: 1 addition & 1 deletion examples-gallery/plot_pendulum_swing_up_fixed_duration.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
obj, obj_grad = create_objective_function(obj_func, state_symbols,
specified_symbols, tuple(),
num_nodes,
node_time_interval=interval_value,
interval_value,
time_symbol=t)

# %%
Expand Down
2 changes: 1 addition & 1 deletion opty/tests/test_direct_collocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_pendulum():
obj_func = sym.Integral(T(t)**2, t)
obj, obj_grad = create_objective_function(
obj_func, state_symbols, specified_symbols, tuple(), num_nodes,
node_time_interval=interval_value, time_symbol=t)
interval_value, time_symbol=t)

# Specify the symbolic instance constraints, i.e. initial and end
# conditions.
Expand Down
60 changes: 35 additions & 25 deletions opty/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def _print_Function(self, expr):
return name + '_'


def ccode(expr, assign_to=None, standard='c99', **settings):
def ccode(expr, assign_to=None, **settings):
"""Mimics SymPy's ccode, but uses our printer."""
return OptyC99CodePrinter(settings).doprint(expr, assign_to)

Expand Down Expand Up @@ -280,34 +280,39 @@ def parse_free(free, n, q, N, variable_duration=False):
return free_states, free_specified, free_constants


def create_objective_function(objective, state_symbols, input_symbols,
unknown_symbols, N, node_time_interval=1.0,
def create_objective_function(objective, state_symbols,
unknown_input_trajectories, unknown_parameters,
num_collocation_nodes, node_time_interval,
integration_method="backward euler",
time_symbol=me.dynamicsymbols._t):
"""Creates function to evaluate the objective and objective gradient.
Parameters
----------
objective : sympy.Expr
The objective function to be minimized. It should solely depend on the
states, inputs, and unknown symbols. Any known symbols should be
substituted beforehand. Additionally, the objective function can contain
non-nested indefinite integrals of time, e.g. ``Integral(f(t)**2, t)``.
state_symbols : iterable of symbols
The state variables.
input_symbols : iterable of symbols
The input variables.
unknown_symbols : iterable of symbols
The unknown parameters.
N : int
objective : Expr
Symbolic objective function to be minimized. It should solely depend on
the states, unknown inputs, and unknown parameters. Any known inputs or
parameters should be substituted beforehand. Additionally, the
objective function can contain non-nested indefinite integrals of time,
e.g. ``Integral(f(t)**2, t)``.
state_symbols : iterable of Function()(t)
An iterable containing all ``n`` of the SymPy functions of time which
represent the states in the equations of motion.
unknown_input_trajectories : iterable of Function()(t)
An iterable containing all ``q`` of the SymPy functions of time which
represent the unknown input trajectories in the equations of motion.
unknown_parameters : iterable of Symbol
An iterable containing all ``r`` of the SymPy symbols which represent
the unknown parameters in the equations of motion.
num_collocation_nodes : int
Number of collocation nodes, i.e. the number of time steps.
node_time_interval : float
The value of the time interval. The default is 1.0, as this term only
appears in the objective function as a scaling factor.
integration_method : str, optional
The method used to integrate the system. The default is "backward
euler".
time_symbol : sympy.Symbol
The method used to integrate the system. The default is ``"backward
euler"``.
time_symbol : Symbol, optional
If not supplied, ``sympy.physics.mechanics.dynamicsymbols._t`` is used.
"""
Expand All @@ -328,7 +333,8 @@ def parse_expr(expr, in_integral=False):
return expr
if isinstance(expr, sm.Integral):
if in_integral:
raise NotImplementedError("Nested integrals are not supported.")
msg = "Nested integrals are not supported."
raise NotImplementedError(msg)
if expr.limits != ((time_symbol,),):
raise NotImplementedError(
"Only indefinite integrals of time are supported.")
Expand All @@ -337,12 +343,13 @@ def parse_expr(expr, in_integral=False):

# Parse function arguments
states = sm.ImmutableMatrix(state_symbols)
inputs = sm.ImmutableMatrix(sort_sympy(input_symbols))
params = sm.ImmutableMatrix(sort_sympy(unknown_symbols))
inputs = sm.ImmutableMatrix(sort_sympy(unknown_input_trajectories))
params = sm.ImmutableMatrix(sort_sympy(unknown_parameters))
if states.shape[1] > 1 or inputs.shape[1] > 1 or params.shape[1] > 1:
raise ValueError(
'The state, input, and unknown symbols must be column matrices.')
n, q = states.shape[0], inputs.shape[0]
N = num_collocation_nodes
i_idx, r_idx = n * N, (n + q) * N

# Compute analytical gradient of the objective function
Expand All @@ -356,9 +363,10 @@ def parse_expr(expr, in_integral=False):

# Replace zeros with an array of zeros, otherwise lambdify will return a
# scalar zero instead of an array of zeros.
objective_grad = tuple(
np.zeros(N) if grad == 0 else grad for grad in objective_grad[:n + q]
) + tuple(objective_grad[n + q:])
objective_grad = (tuple(np.zeros(N)
if grad == 0 else grad
for grad in objective_grad[:n + q]) +
tuple(objective_grad[n + q:]))

# Define evaluation functions based on the integration method
if integration_method == "backward euler":
Expand All @@ -368,6 +376,7 @@ def parse_expr(expr, in_integral=False):
objective_grad[:n + q], np.hstack((0, np.ones(N - 1))), False)
obj_grad_param_expr_eval = lambdify_function(
objective_grad[n + q:], np.hstack((0, np.ones(N - 1))), True)

def obj(free):
states = free[:i_idx].reshape((n, N))
inputs = free[i_idx:r_idx].reshape((q, N))
Expand All @@ -379,7 +388,7 @@ def obj_grad(free):
return np.hstack((
*obj_grad_time_expr_eval(states, inputs, free[r_idx:]),
obj_grad_param_expr_eval(states, inputs, free[r_idx:])
))
))

elif integration_method == "midpoint":
obj_expr_eval = lambdify_function(
Expand All @@ -389,6 +398,7 @@ def obj_grad(free):
False)
obj_grad_param_expr_eval = lambdify_function(
objective_grad[n + q:], np.ones(N - 1), True)

def obj(free):
states = free[:i_idx].reshape((n, N))
states_mid = 0.5 * (states[:, :-1] + states[:, 1:])
Expand Down

0 comments on commit 8dc15d6

Please sign in to comment.