##############################################
# This file is borrowed from UM-ARM-Lab/pytorch_mppi
# You can find more information at https://github.com/UM-ARM-Lab/pytorch_mppi
##############################################
import functools
import logging
import time
import numpy as np
import torch
from torch.distributions.multivariate_normal import MultivariateNormal
logger = logging.getLogger(__name__)
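
# exponentiate the beta-shifted costs: subtracting the minimum cost beta keeps
# the largest weight at exactly 1 and avoids numerical underflow in the weights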
def _ensure_non_zero(cost, beta, factor):
    return torch.exp(-factor * (cost - beta))


def is_tensor_like(x):
    return torch.is_tensor(x) or type(x) is np.ndarray


def squeeze_n(v, n_squeeze):
    for _ in range(n_squeeze):
        v = v.squeeze(0)
    return v

# from arm_pytorch_utilities, standalone since that package is not on pypi yet
def handle_batch_input(n):
def _handle_batch_input(func):
"""For func that expect 2D input, handle input that have more than 2 dimensions by flattening them temporarily"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
            # assume tensor-like inputs have compatible shapes and are represented by the first tensor-like argument
batch_dims = []
for arg in args:
if is_tensor_like(arg):
if len(arg.shape) > n:
# last dimension is type dependent; all previous ones are batches
batch_dims = arg.shape[:-(n - 1)]
break
elif len(arg.shape) < n:
n_batch_dims_to_add = n - len(arg.shape)
batch_ones_to_add = [1] * n_batch_dims_to_add
args = [v.view(*batch_ones_to_add, *v.shape) if is_tensor_like(v) else v for v in args]
ret = func(*args, **kwargs)
if isinstance(ret, tuple):
ret = [squeeze_n(v, n_batch_dims_to_add) if is_tensor_like(v) else v for v in ret]
return ret
else:
if is_tensor_like(ret):
return squeeze_n(ret, n_batch_dims_to_add)
else:
return ret
# no batches; just return normally
if not batch_dims:
return func(*args, **kwargs)
# reduce all batch dimensions down to the first one
args = [v.view(-1, *v.shape[-(n - 1):]) if (is_tensor_like(v) and len(v.shape) > 2) else v for v in args]
ret = func(*args, **kwargs)
# restore original batch dimensions; keep variable dimension (nx)
if type(ret) is tuple:
ret = [v if (not is_tensor_like(v) or len(v.shape) == 0) else (
v.view(*batch_dims, *v.shape[-(n - 1):]) if len(v.shape) == n else v.view(*batch_dims)) for v in
ret]
else:
if is_tensor_like(ret):
if len(ret.shape) == n:
ret = ret.view(*batch_dims, *ret.shape[-(n - 1):])
else:
ret = ret.view(*batch_dims)
return ret
return wrapper
    return _handle_batch_input

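# For example, a dynamics function written for (K x nx, K x nu) inputs can be
# called with (B x K x nx, B x K x nu) tensors; the decorator flattens them to
# ((B*K) x nx, (B*K) x nu) for the call and restores the batch dimensions on
# the returned tensors.
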
class MPPI():
"""
Model Predictive Path Integral control
This implementation batch samples the trajectories and so scales well with the number of samples K.
Implemented according to algorithm 2 in Williams et al., 2017
'Information Theoretic MPC for Model-Based Reinforcement Learning',
    based on https://github.com/ferreirafabio/mppi_pendulum
"""
def __init__(self, dynamics, running_cost, nx, noise_sigma, num_samples=100, horizon=15, device="cpu",
terminal_state_cost=None,
lambda_=1.,
noise_mu=None,
u_min=None,
u_max=None,
u_init=None,
U_init=None,
u_scale=1,
u_per_command=1,
step_dependent_dynamics=False,
rollout_samples=1,
rollout_var_cost=0,
rollout_var_discount=0.95,
sample_null_action=False,
noise_abs_cost=False):
"""
:param dynamics: function(state, action) -> next_state (K x nx) taking in batch state (K x nx) and action (K x nu)
:param running_cost: function(state, action) -> cost (K) taking in batch state and action (same as dynamics)
:param nx: state dimension
:param noise_sigma: (nu x nu) control noise covariance (assume v_t ~ N(u_t, noise_sigma))
:param num_samples: K, number of trajectories to sample
:param horizon: T, length of each trajectory
:param device: pytorch device
:param terminal_state_cost: function(state) -> cost (K x 1) taking in batch state
:param lambda_: temperature, positive scalar where larger values will allow more exploration
:param noise_mu: (nu) control noise mean (used to bias control samples); defaults to zero mean
:param u_min: (nu) minimum values for each dimension of control to pass into dynamics
:param u_max: (nu) maximum values for each dimension of control to pass into dynamics
        :param u_init: (nu) what to initialize new end of trajectory control to be; defaults to zero
        :param U_init: (T x nu) initial control sequence; defaults to noise
        :param step_dependent_dynamics: whether the passed-in dynamics needs the horizon step passed in (as its 3rd argument)
        :param rollout_samples: M, number of state trajectories to rollout for each control trajectory
            (should be 1 for deterministic dynamics and more for models that output a distribution)
        :param rollout_var_cost: Cost attached to the variance of costs across trajectory rollouts
        :param rollout_var_discount: Discount of variance cost over control horizon
        :param sample_null_action: Whether to explicitly sample a null action (bad for starting in a local minimum)
        :param noise_abs_cost: Whether to use the absolute value of the action noise to avoid bias when all states have the same cost
"""
self.d = device
self.dtype = noise_sigma.dtype
self.K = num_samples # N_SAMPLES
self.T = horizon # TIMESTEPS
# dimensions of state and control
self.nx = nx
self.nu = 1 if len(noise_sigma.shape) == 0 else noise_sigma.shape[0]
self.lambda_ = lambda_
if noise_mu is None:
noise_mu = torch.zeros(self.nu, dtype=self.dtype)
if u_init is None:
u_init = torch.zeros_like(noise_mu)
# handle 1D edge case
if self.nu == 1:
noise_mu = noise_mu.view(-1)
noise_sigma = noise_sigma.view(-1, 1)
# bounds
self.u_min = u_min
self.u_max = u_max
self.u_scale = u_scale
self.u_per_command = u_per_command
        # if only one bound is specified, mirror it so both are specified
if self.u_max is not None and self.u_min is None:
if not torch.is_tensor(self.u_max):
self.u_max = torch.tensor(self.u_max)
self.u_min = -self.u_max
if self.u_min is not None and self.u_max is None:
if not torch.is_tensor(self.u_min):
self.u_min = torch.tensor(self.u_min)
self.u_max = -self.u_min
if self.u_min is not None:
self.u_min = self.u_min.to(device=self.d)
self.u_max = self.u_max.to(device=self.d)
self.noise_mu = noise_mu.to(self.d)
self.noise_sigma = noise_sigma.to(self.d)
self.noise_sigma_inv = torch.inverse(self.noise_sigma)
self.noise_dist = MultivariateNormal(self.noise_mu, covariance_matrix=self.noise_sigma)
# T x nu control sequence
self.U = U_init
self.u_init = u_init.to(self.d)
if self.U is None:
self.U = self.noise_dist.sample((self.T,))
self.step_dependency = step_dependent_dynamics
self.F = dynamics
self.running_cost = running_cost
self.terminal_state_cost = terminal_state_cost
self.sample_null_action = sample_null_action
self.noise_abs_cost = noise_abs_cost
self.state = None
# handling dynamics models that output a distribution (take multiple trajectory samples)
self.M = rollout_samples
self.rollout_var_cost = rollout_var_cost
self.rollout_var_discount = rollout_var_discount
# sampled results from last command
self.cost_total = None
self.cost_total_non_zero = None
self.omega = None
self.states = None
self.actions = None
@handle_batch_input(n=2)
def _dynamics(self, state, u, t):
return self.F(state, u, t) if self.step_dependency else self.F(state, u)
@handle_batch_input(n=2)
def _running_cost(self, state, u):
return self.running_cost(state, u)
def command(self, state):
"""
:param state: (nx) or (K x nx) current state, or samples of states (for propagating a distribution of states)
:returns action: (nu) best action
"""
# shift command 1 time step
self.U = torch.roll(self.U, -1, dims=0)
self.U[-1] = self.u_init
return self._command(state)
def _command(self, state):
if not torch.is_tensor(state):
state = torch.tensor(state)
self.state = state.to(dtype=self.dtype, device=self.d)
cost_total = self._compute_total_cost_batch()
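        # importance-sampling weights: omega_k is proportional to exp(-(1/lambda) * (S_k - beta));
        # subtracting the minimum cost beta before exponentiating avoids underflow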
beta = torch.min(cost_total)
self.cost_total_non_zero = _ensure_non_zero(cost_total, beta, 1 / self.lambda_)
eta = torch.sum(self.cost_total_non_zero)
self.omega = (1. / eta) * self.cost_total_non_zero
for t in range(self.T):
self.U[t] += torch.sum(self.omega.view(-1, 1) * self.noise[:, t], dim=0)
action = self.U[:self.u_per_command]
# reduce dimensionality if we only need the first command
if self.u_per_command == 1:
action = action[0]
return action
def reset(self):
"""
Clear controller state after finishing a trial
"""
self.U = self.noise_dist.sample((self.T,))
def _compute_rollout_costs(self, perturbed_actions):
K, T, nu = perturbed_actions.shape
assert nu == self.nu
cost_total = torch.zeros(K, device=self.d, dtype=self.dtype)
cost_samples = cost_total.repeat(self.M, 1)
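        # cost_samples is M x K: per-trajectory costs for each of the M rollout samples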
cost_var = torch.zeros_like(cost_total)
        # allow propagation of a sample of states (e.g. to carry a distribution), or to start with a single state
if self.state.shape == (K, self.nx):
state = self.state
else:
state = self.state.view(1, -1).repeat(K, 1)
# rollout action trajectory M times to estimate expected cost
state = state.repeat(self.M, 1, 1)
states = []
actions = []
for t in range(T):
u = self.u_scale * perturbed_actions[:, t].repeat(self.M, 1, 1)
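            # u is M x K x nu: the K perturbed action samples repeated for each of the M rollouts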
state = self._dynamics(state, u, t)
c = self._running_cost(state, u)
cost_samples += c
if self.M > 1:
cost_var += c.var(dim=0) * (self.rollout_var_discount ** t)
# Save total states/actions
states.append(state)
actions.append(u)
        # Actions is M x K x T x nu
        # States is M x K x T x nx
actions = torch.stack(actions, dim=-2)
states = torch.stack(states, dim=-2)
        # terminal state cost
if self.terminal_state_cost:
c = self.terminal_state_cost(states, actions)
cost_samples += c
cost_total += cost_samples.mean(dim=0)
cost_total += cost_var * self.rollout_var_cost
return cost_total, states, actions
def _compute_total_cost_batch(self):
# parallelize sampling across trajectories
# resample noise each time we take an action
self.noise = self.noise_dist.sample((self.K, self.T))
# broadcast own control to noise over samples; now it's K x T x nu
self.perturbed_action = self.U + self.noise
if self.sample_null_action:
self.perturbed_action[self.K - 1] = 0
# naively bound control
self.perturbed_action = self._bound_action(self.perturbed_action)
        # recompute noise from the bounded actions (perturbations that were clipped off should not be penalized in the action cost)
self.noise = self.perturbed_action - self.U
if self.noise_abs_cost:
action_cost = self.lambda_ * torch.abs(self.noise) @ self.noise_sigma_inv
            # NOTE: The original paper does self.lambda_ * self.noise @ self.noise_sigma_inv, but this biases
            # the actions with low noise if all states have the same cost. With abs(noise) we prefer actions close to the
            # nominal trajectory.
else:
action_cost = self.lambda_ * self.noise @ self.noise_sigma_inv # Like original paper
self.cost_total, self.states, self.actions = self._compute_rollout_costs(self.perturbed_action)
self.actions /= self.u_scale
# action perturbation cost
perturbation_cost = torch.sum(self.U * action_cost, dim=(1, 2))
self.cost_total += perturbation_cost
return self.cost_total
def _bound_action(self, action):
if self.u_max is not None:
for t in range(self.T):
u = action[:, self._slice_control(t)]
cu = torch.max(torch.min(u, self.u_max), self.u_min)
action[:, self._slice_control(t)] = cu
return action
def _slice_control(self, t):
return slice(t * self.nu, (t + 1) * self.nu)
def get_rollouts(self, state, num_rollouts=1):
"""
:param state: either (nx) vector or (num_rollouts x nx) for sampled initial states
:param num_rollouts: Number of rollouts with same action sequence - for generating samples with stochastic
dynamics
:returns states: num_rollouts x T x nx vector of trajectories
"""
state = state.view(-1, self.nx)
if state.size(0) == 1:
state = state.repeat(num_rollouts, 1)
T = self.U.shape[0]
states = torch.zeros((num_rollouts, T + 1, self.nx), dtype=self.U.dtype, device=self.U.device)
states[:, 0] = state
for t in range(T):
            states[:, t + 1] = self._dynamics(states[:, t].view(num_rollouts, -1),
                                              self.u_scale * self.U[t].repeat(num_rollouts, 1), t)
        return states[:, 1:]

def run_mppi(mppi, env, retrain_dynamics, retrain_after_iter=50, iter=1000, render=True):
dataset = torch.zeros((retrain_after_iter, mppi.nx + mppi.nu), dtype=mppi.U.dtype, device=mppi.d)
total_reward = 0
for i in range(iter):
state = env.state.copy()
command_start = time.perf_counter()
action = mppi.command(state)
elapsed = time.perf_counter() - command_start
s, r, _, _ = env.step(action.cpu().numpy())
total_reward += r
logger.debug("action taken: %.4f cost received: %.4f time taken: %.5fs", action, -r, elapsed)
if render:
env.render()
di = i % retrain_after_iter
if di == 0 and i > 0:
retrain_dynamics(dataset)
# don't have to clear dataset since it'll be overridden, but useful for debugging
dataset.zero_()
dataset[di, :mppi.nx] = torch.tensor(state, dtype=mppi.U.dtype)
dataset[di, mppi.nx:] = action
return total_reward, dataset
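
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the borrowed file): drive a
# 1-D double integrator to the origin. The dynamics and cost functions below
# are assumptions for demonstration; any batched
# function(state, action) -> next_state / cost with the shapes documented in
# MPPI.__init__ works the same way.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    dt = 0.05

    def di_dynamics(state, action):
        # state is K x 2 (position, velocity); action is K x 1 (acceleration)
        pos = state[:, 0] + state[:, 1] * dt
        vel = state[:, 1] + action[:, 0] * dt
        return torch.stack((pos, vel), dim=1)

    def di_cost(state, action):
        # quadratic penalty on distance from the origin and on velocity
        return state[:, 0] ** 2 + 0.1 * state[:, 1] ** 2

    ctrl = MPPI(di_dynamics, di_cost, nx=2,
                noise_sigma=torch.eye(1),
                num_samples=500, horizon=20, lambda_=1.,
                u_min=torch.tensor([-2.0]), u_max=torch.tensor([2.0]))
    state = torch.tensor([1.0, 0.0])
    for _ in range(100):
        action = ctrl.command(state)  # first action of the optimized sequence
        state = di_dynamics(state.view(1, -1), action.view(1, -1))[0]
    print("final state:", state)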