From 9e576653b8a5c0a3f993dab9626f202c1d8d5308 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 19 Sep 2024 17:51:13 +0200 Subject: [PATCH 1/4] Remove `is_urdf` argument from the API --- src/jaxsim/api/model.py | 6 ++-- src/jaxsim/mujoco/loaders.py | 4 +-- src/jaxsim/parsers/rod/parser.py | 50 ++++++++++++++++++-------------- tests/conftest.py | 15 ---------- tests/test_api_contact.py | 2 +- tests/test_pytree.py | 6 ++-- 6 files changed, 36 insertions(+), 47 deletions(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index d9c0a1edc..21d1ef923 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -106,8 +106,8 @@ def build_from_model_description( terrain: The optional terrain to consider. is_urdf: - Whether the model description is a URDF or an SDF. This is - automatically inferred if the model description is a path to a file. + The optional flag to force the model description to be parsed as a + URDF or a SDF. This is otherwise automatically inferred. considered_joints: The list of joints to consider. If None, all joints are considered. @@ -120,7 +120,7 @@ def build_from_model_description( # Parse the input resource (either a path to file or a string with the URDF/SDF) # and build the -intermediate- model description. intermediate_description = jaxsim.parsers.rod.build_model_description( - model_description=model_description, is_urdf=is_urdf + model_description=model_description ) # Lump links together if not all joints are considered. diff --git a/src/jaxsim/mujoco/loaders.py b/src/jaxsim/mujoco/loaders.py index 3ea44233b..e4426a9e4 100644 --- a/src/jaxsim/mujoco/loaders.py +++ b/src/jaxsim/mujoco/loaders.py @@ -25,7 +25,7 @@ def load_rod_model( Args: model_description: The URDF/SDF file or ROD model to load. - is_urdf: Whether the model description is a URDF file. + is_urdf: Whether to force parsing the model description as a URDF file. model_name: The name of the model to load from the resource. Returns: @@ -558,7 +558,6 @@ def convert( # Get the ROD model. rod_model = load_rod_model( model_description=urdf, - is_urdf=True, model_name=model_name, ) @@ -605,7 +604,6 @@ def convert( # Get the ROD model. rod_model = load_rod_model( model_description=sdf, - is_urdf=False, model_name=model_name, ) diff --git a/src/jaxsim/parsers/rod/parser.py b/src/jaxsim/parsers/rod/parser.py index 76aadf435..00f237a79 100644 --- a/src/jaxsim/parsers/rod/parser.py +++ b/src/jaxsim/parsers/rod/parser.py @@ -41,30 +41,37 @@ def extract_model_data( Extract data from an SDF/URDF resource useful to build a JaxSim model. Args: - model_description: A path to an SDF/URDF file, a string containing its content, - or a pre-parsed/pre-built rod model. + model_description: + A path to an SDF/URDF file, a string containing its content, or + a pre-parsed/pre-built rod model. model_name: The name of the model to extract from the SDF resource. - is_urdf: Whether the SDF resource is a URDF file. Needed only if model_description - is a URDF string. + is_urdf: + Whether to force parsing the resource as a URDF file. Automatically + detected if not provided. Returns: The extracted model data. """ - if isinstance(model_description, rod.Model): - sdf_model = model_description - else: - # Parse the SDF resource. - sdf_element = rod.Sdf.load(sdf=model_description, is_urdf=is_urdf) - - if len(sdf_element.models()) == 0: - raise RuntimeError("Failed to find any model in SDF resource") - - # Assume the SDF resource has only one model, or the desired model name is given. - sdf_models = {m.name: m for m in sdf_element.models()} - sdf_model = ( - sdf_element.models()[0] if len(sdf_models) == 1 else sdf_models[model_name] - ) + match model_description: + case rod.Model(): + sdf_model = model_description + case rod.Sdf(None): + sdf_model = model_description.model + case str() | pathlib.Path(): + # Parse the SDF resource. + sdf_element = rod.Sdf.load(sdf=model_description, is_urdf=is_urdf) + + if len(sdf_element.models()) == 0: + raise RuntimeError("Failed to find any model in SDF resource") + + # Assume the SDF resource has only one model, or the desired model name is given. + sdf_models = {m.name: m for m in sdf_element.models()} + sdf_model = ( + sdf_element.models()[0] + if len(sdf_models) == 1 + else sdf_models[model_name] + ) # Log model name. logging.debug(msg=f"Found model '{sdf_model.name}' in SDF resource") @@ -344,7 +351,7 @@ def extract_model_data( def build_model_description( model_description: pathlib.Path | str | rod.Model, - is_urdf: bool | None = False, + is_urdf: bool | None = None, ) -> descriptions.ModelDescription: """ Builds a model description from an SDF/URDF resource. @@ -352,8 +359,9 @@ def build_model_description( Args: model_description: A path to an SDF/URDF file, a string containing its content, or a pre-parsed/pre-built rod model. - is_urdf: Whether the SDF resource is a URDF file. Needed only if model_description - is a URDF string. + is_urdf: Whether the force parsing the resource as a URDF file. Automatically + detected if not provided. + Returns: The parsed model description. """ diff --git a/tests/conftest.py b/tests/conftest.py index 10349f4c2..b3d5fb1a5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -83,24 +83,9 @@ def build_jaxsim_model( A JaxSim model built from the provided description. """ - is_urdf = None - - # If the provided description is a string, automatically detect if it - # contains the content of a URDF or SDF file. - if isinstance(model_description, str): - if "" in model_description: - is_urdf = False - - else: - is_urdf = None - # Build the JaxSim model. model = js.model.JaxSimModel.build_from_model_description( model_description=model_description, - is_urdf=is_urdf, ) return model diff --git a/tests/test_api_contact.py b/tests/test_api_contact.py index f055c18a6..6456f2645 100644 --- a/tests/test_api_contact.py +++ b/tests/test_api_contact.py @@ -85,7 +85,7 @@ def test_contact_jacobian_derivative( W_p_Ci = model.kin_dyn_parameters.contact_parameters.point # Load the model in ROD. - rod_model = rod.Sdf.load(sdf=model.built_from, is_urdf=True).model + rod_model = rod.Sdf.load(sdf=model.built_from).model # Add dummy frames on the contact points. for idx, (link_name, W_p_C) in enumerate( diff --git a/tests/test_pytree.py b/tests/test_pytree.py index 8d063dc87..561d5ac53 100644 --- a/tests/test_pytree.py +++ b/tests/test_pytree.py @@ -14,14 +14,12 @@ def test_call_jit_compiled_function_passing_different_objects( # Create a first model from the URDF. model1 = js.model.JaxSimModel.build_from_model_description( - model_description=ergocub_model_description_path, - is_urdf=True, + model_description=ergocub_model_description_path ) # Create a second model from the URDF. model2 = js.model.JaxSimModel.build_from_model_description( - model_description=ergocub_model_description_path, - is_urdf=True, + model_description=ergocub_model_description_path ) # The objects should be different, but the comparison should return True. From abaa1fe13282ce6a7b341ec756f9a95c20643c89 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 19 Sep 2024 18:12:09 +0200 Subject: [PATCH 2/4] Remove `is_urdf` argument from example notebooks --- examples/PD_controller.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/PD_controller.ipynb b/examples/PD_controller.ipynb index b68c5b783..5f401ff1c 100644 --- a/examples/PD_controller.ipynb +++ b/examples/PD_controller.ipynb @@ -100,7 +100,7 @@ "num_steps = int(integration_time / dt)\n", "\n", "model = js.model.JaxSimModel.build_from_model_description(\n", - " model_description=model_urdf_string, is_urdf=True\n", + " model_description=model_urdf_string\n", ")\n", "data = js.data.JaxSimModelData.build(model=model)\n", "integrator = integrators.fixed_step.RungeKutta4SO3.build(\n", From f8b9f4c234e4035c9d1831afe9c32c15f93d7b81 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 23 Sep 2024 10:01:53 +0200 Subject: [PATCH 3/4] Bump minimum `rod` requirement to 0.3.3 --- environment.yml | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/environment.yml b/environment.yml index 8a3d6de0b..334dba22c 100644 --- a/environment.yml +++ b/environment.yml @@ -14,7 +14,7 @@ dependencies: - jax-dataclasses >= 1.4.0 - pptree - qpax - - rod >= 0.3.0 + - rod >= 0.3.3 - typing_extensions # python<3.12 # ==================================== # Optional dependencies from setup.cfg diff --git a/pyproject.toml b/pyproject.toml index 032f970a8..843c85bfc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,7 @@ dependencies = [ "jax_dataclasses >= 1.4.0", "pptree", "qpax", - "rod >= 0.3.0", + "rod >= 0.3.3", "typing_extensions ; python_version < '3.12'", ] From cf1ff07d9d9249c1c69954507b7fc6c809d618ae Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 25 Sep 2024 11:45:54 +0200 Subject: [PATCH 4/4] Address suggestions from code review Co-authored-by: Diego Ferigo --- src/jaxsim/mujoco/loaders.py | 2 ++ src/jaxsim/parsers/rod/parser.py | 16 ++++++++-------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/jaxsim/mujoco/loaders.py b/src/jaxsim/mujoco/loaders.py index e4426a9e4..a15a3afc7 100644 --- a/src/jaxsim/mujoco/loaders.py +++ b/src/jaxsim/mujoco/loaders.py @@ -558,6 +558,7 @@ def convert( # Get the ROD model. rod_model = load_rod_model( model_description=urdf, + is_urdf=True, model_name=model_name, ) @@ -604,6 +605,7 @@ def convert( # Get the ROD model. rod_model = load_rod_model( model_description=sdf, + is_urdf=False, model_name=model_name, ) diff --git a/src/jaxsim/parsers/rod/parser.py b/src/jaxsim/parsers/rod/parser.py index 00f237a79..cee3696bb 100644 --- a/src/jaxsim/parsers/rod/parser.py +++ b/src/jaxsim/parsers/rod/parser.py @@ -33,7 +33,7 @@ class SDFData(NamedTuple): def extract_model_data( - model_description: pathlib.Path | str | rod.Model, + model_description: pathlib.Path | str | rod.Model | rod.Sdf, model_name: str | None = None, is_urdf: bool | None = None, ) -> SDFData: @@ -56,13 +56,13 @@ def extract_model_data( match model_description: case rod.Model(): sdf_model = model_description - case rod.Sdf(None): - sdf_model = model_description.model - case str() | pathlib.Path(): - # Parse the SDF resource. - sdf_element = rod.Sdf.load(sdf=model_description, is_urdf=is_urdf) - - if len(sdf_element.models()) == 0: + case rod.Sdf() | str() | pathlib.Path(): + sdf_element = ( + model_description + if isinstance(model_description, rod.Sdf) + else rod.Sdf.load(sdf=model_description, is_urdf=is_urdf) + ) + if not sdf_element.models(): raise RuntimeError("Failed to find any model in SDF resource") # Assume the SDF resource has only one model, or the desired model name is given.