From 2af4b8b8c07f45e3e64689d7a2799a097d13a609 Mon Sep 17 00:00:00 2001 From: Alessandro Croci Date: Tue, 24 Sep 2024 11:26:52 +0200 Subject: [PATCH 1/2] Remove redundant calls toget kindyn quantities in `RigidContacts` Co-authored-by: diegoferigo --- src/jaxsim/rbda/contacts/rigid.py | 39 ++++++++++++++----------------- 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py index f052fa30..55862b4f 100644 --- a/src/jaxsim/rbda/contacts/rigid.py +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -215,6 +215,7 @@ def compute_contact_forces( link_forces: jtp.MatrixLike | None = None, joint_force_references: jtp.VectorLike | None = None, regularization_term: jtp.FloatLike = 1e-6, + solver_tol: jtp.FloatLike = 1e-3, ) -> tuple[jtp.Vector, tuple[Any, ...]]: """ Compute the contact forces. @@ -257,6 +258,9 @@ def compute_contact_forces( M = js.model.free_floating_mass_matrix(model=model, data=data) J_WC = js.contact.jacobian(model=model, data=data) W_H_C = js.contact.transforms(model=model, data=data) + J̇_WC_BW = js.contact.jacobian_derivative(model=model, data=data) + BW_ν = data.generalized_velocity() + terrain_height = jax.vmap(self.terrain.height)(position[:, 0], position[:, 1]) n_collidable_points = model.kin_dyn_parameters.contact_parameters.point.shape[0] @@ -295,9 +299,10 @@ def compute_contact_forces( ) free_contact_acc = RigidContacts._linear_acceleration_of_collidable_points( - model, - data, - BW_ν̇_free, + BW_nu=BW_ν, + BW_nu_dot=BW_ν̇_free, + CW_J_WC_BW=J_WC, + CW_J_dot_WC_BW=J̇_WC_BW, ).flatten() # Compute stabilization term @@ -325,7 +330,9 @@ def compute_contact_forces( b = jnp.zeros((0,)) # Solve the optimization problem - solution, *_ = qpax.solve_qp(Q=Q, q=q, A=A, b=b, G=G, h=h_bounds) + solution, *_ = qpax.solve_qp( + Q=Q, q=q, A=A, b=b, G=G, h=h_bounds, solver_tol=solver_tol + ) f_C_lin = solution.reshape(-1, 3) @@ -399,24 +406,14 @@ def _compute_ineq_bounds(n_collidable_points: jtp.FloatLike) -> jtp.Vector: @staticmethod def _linear_acceleration_of_collidable_points( - model: js.model.JaxSimModel, - data: js.data.JaxSimModelData, - mixed_nu_dot: jtp.ArrayLike, + BW_nu: jtp.ArrayLike, + BW_nu_dot: jtp.ArrayLike, + CW_J_WC_BW: jtp.MatrixLike, + CW_J_dot_WC_BW: jtp.MatrixLike, ) -> jtp.Matrix: - with data.switch_velocity_representation(VelRepr.Mixed): - CW_J_WC_BW = js.contact.jacobian( - model=model, - data=data, - output_vel_repr=VelRepr.Mixed, - ) - CW_J̇_WC_BW = js.contact.jacobian_derivative( - model=model, - data=data, - output_vel_repr=VelRepr.Mixed, - ) - - BW_ν = data.generalized_velocity() - BW_ν̇ = mixed_nu_dot + CW_J̇_WC_BW = CW_J_dot_WC_BW + BW_ν = BW_nu + BW_ν̇ = BW_nu_dot CW_a_WC = jnp.vstack(CW_J̇_WC_BW) @ BW_ν + jnp.vstack(CW_J_WC_BW) @ BW_ν̇ CW_a_WC = CW_a_WC.reshape(-1, 6) From effbb0e2644832bbfcb751d6c6f580585e09c58d Mon Sep 17 00:00:00 2001 From: Alessandro Croci Date: Tue, 24 Sep 2024 11:26:53 +0200 Subject: [PATCH 2/2] Update `contact.py` --- src/jaxsim/api/contact.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 75c85df2..d62b45bb 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -208,6 +208,7 @@ def collidable_point_dynamics( data=data, link_forces=link_forces, joint_force_references=joint_force_references, + solver_tol=1e-3, ) aux_data = dict()