Skip to content

Commit

Permalink
Merge pull request #237 from xela-95/fix/rigid-contacts-speedup
Browse files Browse the repository at this point in the history
Rigid contacts speedup improvements
  • Loading branch information
diegoferigo authored Sep 25, 2024
2 parents db72bf5 + effbb0e commit 100b60c
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 21 deletions.
1 change: 1 addition & 0 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def collidable_point_dynamics(
data=data,
link_forces=link_forces,
joint_force_references=joint_force_references,
solver_tol=1e-3,
)

aux_data = dict()
Expand Down
39 changes: 18 additions & 21 deletions src/jaxsim/rbda/contacts/rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def compute_contact_forces(
link_forces: jtp.MatrixLike | None = None,
joint_force_references: jtp.VectorLike | None = None,
regularization_term: jtp.FloatLike = 1e-6,
solver_tol: jtp.FloatLike = 1e-3,
) -> tuple[jtp.Vector, tuple[Any, ...]]:
"""
Compute the contact forces.
Expand Down Expand Up @@ -257,6 +258,9 @@ def compute_contact_forces(
M = js.model.free_floating_mass_matrix(model=model, data=data)
J_WC = js.contact.jacobian(model=model, data=data)
W_H_C = js.contact.transforms(model=model, data=data)
J̇_WC_BW = js.contact.jacobian_derivative(model=model, data=data)
BW_ν = data.generalized_velocity()

terrain_height = jax.vmap(self.terrain.height)(position[:, 0], position[:, 1])
n_collidable_points = model.kin_dyn_parameters.contact_parameters.point.shape[0]

Expand Down Expand Up @@ -295,9 +299,10 @@ def compute_contact_forces(
)

free_contact_acc = RigidContacts._linear_acceleration_of_collidable_points(
model,
data,
BW_ν̇_free,
BW_nu=BW_ν,
BW_nu_dot=BW_ν̇_free,
CW_J_WC_BW=J_WC,
CW_J_dot_WC_BW=J̇_WC_BW,
).flatten()

# Compute stabilization term
Expand Down Expand Up @@ -325,7 +330,9 @@ def compute_contact_forces(
b = jnp.zeros((0,))

# Solve the optimization problem
solution, *_ = qpax.solve_qp(Q=Q, q=q, A=A, b=b, G=G, h=h_bounds)
solution, *_ = qpax.solve_qp(
Q=Q, q=q, A=A, b=b, G=G, h=h_bounds, solver_tol=solver_tol
)

f_C_lin = solution.reshape(-1, 3)

Expand Down Expand Up @@ -399,24 +406,14 @@ def _compute_ineq_bounds(n_collidable_points: jtp.FloatLike) -> jtp.Vector:

@staticmethod
def _linear_acceleration_of_collidable_points(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
mixed_nu_dot: jtp.ArrayLike,
BW_nu: jtp.ArrayLike,
BW_nu_dot: jtp.ArrayLike,
CW_J_WC_BW: jtp.MatrixLike,
CW_J_dot_WC_BW: jtp.MatrixLike,
) -> jtp.Matrix:
with data.switch_velocity_representation(VelRepr.Mixed):
CW_J_WC_BW = js.contact.jacobian(
model=model,
data=data,
output_vel_repr=VelRepr.Mixed,
)
CW_J̇_WC_BW = js.contact.jacobian_derivative(
model=model,
data=data,
output_vel_repr=VelRepr.Mixed,
)

BW_ν = data.generalized_velocity()
BW_ν̇ = mixed_nu_dot
CW_J̇_WC_BW = CW_J_dot_WC_BW
BW_ν = BW_nu
BW_ν̇ = BW_nu_dot

CW_a_WC = jnp.vstack(CW_J̇_WC_BW) @ BW_ν + jnp.vstack(CW_J_WC_BW) @ BW_ν̇
CW_a_WC = CW_a_WC.reshape(-1, 6)
Expand Down

0 comments on commit 100b60c

Please sign in to comment.