From e39b27bf7728960ca1af2bc6dfb8247cb8e10691 Mon Sep 17 00:00:00 2001 From: Alessandro Croci Date: Mon, 23 Sep 2024 16:00:06 +0200 Subject: [PATCH 1/2] Update `jaxsim.api.contact.collidable_point_forces` Add external link forces and joint torques as input arguments that are then passed to `collidable_point_dynamics` --- src/jaxsim/api/contact.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 956e72e7..75c85df2 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -93,7 +93,10 @@ def collidable_point_velocities( @jax.jit def collidable_point_forces( - model: js.model.JaxSimModel, data: js.data.JaxSimModelData + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + link_forces: jtp.MatrixLike | None = None, + joint_force_references: jtp.VectorLike | None = None, ) -> jtp.Matrix: """ Compute the 6D forces applied to each collidable point. @@ -101,13 +104,23 @@ def collidable_point_forces( Args: model: The model to consider. data: The data of the considered model. + link_forces: + The 6D external forces to apply to the links expressed in the same + representation of data. + joint_force_references: + The joint force references to apply to the joints. Returns: The 6D forces applied to each collidable point expressed in the frame corresponding to the active representation. """ - f_Ci, _ = collidable_point_dynamics(model=model, data=data) + f_Ci, _ = collidable_point_dynamics( + model=model, + data=data, + link_forces=link_forces, + joint_force_references=joint_force_references, + ) return f_Ci From 9af121c2be213cb3aceb60a0329a5429320532f0 Mon Sep 17 00:00:00 2001 From: Alessandro Croci Date: Mon, 23 Sep 2024 16:01:27 +0200 Subject: [PATCH 2/2] Update `jaxsim.api.model.link_contact_forces` Add external link forces and joint torques as input arguments that are then considered when computing the link contact forces --- src/jaxsim/api/model.py | 48 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 45 insertions(+), 3 deletions(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 602def3f..d9c0a1ed 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -1731,7 +1731,10 @@ def body_to_other_representation( @jax.jit def link_contact_forces( - model: js.model.JaxSimModel, data: js.data.JaxSimModelData + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + link_forces: jtp.MatrixLike | None = None, + joint_force_references: jtp.VectorLike | None = None, ) -> jtp.Matrix: """ Compute the 6D contact forces of all links of the model. @@ -1739,6 +1742,11 @@ def link_contact_forces( Args: model: The model to consider. data: The data of the considered model. + link_forces: + The 6D external forces to apply to the links expressed in the same + representation of data. + joint_force_references: + The joint force references to apply to the joints. Returns: A `(nL, 6)` array containing the stacked 6D contact forces of the links, @@ -1749,10 +1757,44 @@ def link_contact_forces( # `jaxsim.api.ode.system_velocity_dynamics`. We cannot merge them since # there we need to get also aux_data. + # Build link forces if not provided. + # These forces are expressed in the frame corresponding to the velocity + # representation of data. + O_f_L = ( + jnp.atleast_2d(link_forces.squeeze()) + if link_forces is not None + else jnp.zeros((model.number_of_links(), 6)) + ).astype(float) + + # Build joint force references if not provided. + joint_force_references = ( + jnp.atleast_1d(joint_force_references) + if joint_force_references is not None + else jnp.zeros(model.dofs()) + ) + + # We expect that the 6D forces included in the `link_forces` argument are expressed + # in the frame corresponding to the velocity representation of `data`. + input_references = js.references.JaxSimModelReferences.build( + model=model, + data=data, + velocity_representation=data.velocity_representation, + link_forces=O_f_L, + joint_force_references=joint_force_references, + ) + # Compute the 6D forces applied to each collidable point expressed in the # inertial frame. - with data.switch_velocity_representation(VelRepr.Inertial): - W_f_C = js.contact.collidable_point_forces(model=model, data=data) + with ( + data.switch_velocity_representation(VelRepr.Inertial), + input_references.switch_velocity_representation(VelRepr.Inertial), + ): + W_f_C = js.contact.collidable_point_forces( + model=model, + data=data, + link_forces=input_references.link_forces(), + joint_force_references=input_references.joint_force_references(), + ) # Construct the vector defining the parent link index of each collidable point. # We use this vector to sum the 6D forces of all collidable points rigidly