diff --git a/environment.yml b/environment.yml index 8a3d6de0b..334dba22c 100644 --- a/environment.yml +++ b/environment.yml @@ -14,7 +14,7 @@ dependencies: - jax-dataclasses >= 1.4.0 - pptree - qpax - - rod >= 0.3.0 + - rod >= 0.3.3 - typing_extensions # python<3.12 # ==================================== # Optional dependencies from setup.cfg diff --git a/examples/PD_controller.ipynb b/examples/PD_controller.ipynb index b68c5b783..5f401ff1c 100644 --- a/examples/PD_controller.ipynb +++ b/examples/PD_controller.ipynb @@ -100,7 +100,7 @@ "num_steps = int(integration_time / dt)\n", "\n", "model = js.model.JaxSimModel.build_from_model_description(\n", - " model_description=model_urdf_string, is_urdf=True\n", + " model_description=model_urdf_string\n", ")\n", "data = js.data.JaxSimModelData.build(model=model)\n", "integrator = integrators.fixed_step.RungeKutta4SO3.build(\n", diff --git a/pyproject.toml b/pyproject.toml index 032f970a8..843c85bfc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,7 @@ dependencies = [ "jax_dataclasses >= 1.4.0", "pptree", "qpax", - "rod >= 0.3.0", + "rod >= 0.3.3", "typing_extensions ; python_version < '3.12'", ] diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index d9c0a1edc..21d1ef923 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -106,8 +106,8 @@ def build_from_model_description( terrain: The optional terrain to consider. is_urdf: - Whether the model description is a URDF or an SDF. This is - automatically inferred if the model description is a path to a file. + The optional flag to force the model description to be parsed as a + URDF or a SDF. This is otherwise automatically inferred. considered_joints: The list of joints to consider. If None, all joints are considered. @@ -120,7 +120,7 @@ def build_from_model_description( # Parse the input resource (either a path to file or a string with the URDF/SDF) # and build the -intermediate- model description. intermediate_description = jaxsim.parsers.rod.build_model_description( - model_description=model_description, is_urdf=is_urdf + model_description=model_description ) # Lump links together if not all joints are considered. diff --git a/src/jaxsim/mujoco/loaders.py b/src/jaxsim/mujoco/loaders.py index 3ea44233b..a15a3afc7 100644 --- a/src/jaxsim/mujoco/loaders.py +++ b/src/jaxsim/mujoco/loaders.py @@ -25,7 +25,7 @@ def load_rod_model( Args: model_description: The URDF/SDF file or ROD model to load. - is_urdf: Whether the model description is a URDF file. + is_urdf: Whether to force parsing the model description as a URDF file. model_name: The name of the model to load from the resource. Returns: diff --git a/src/jaxsim/parsers/rod/parser.py b/src/jaxsim/parsers/rod/parser.py index 76aadf435..cee3696bb 100644 --- a/src/jaxsim/parsers/rod/parser.py +++ b/src/jaxsim/parsers/rod/parser.py @@ -33,7 +33,7 @@ class SDFData(NamedTuple): def extract_model_data( - model_description: pathlib.Path | str | rod.Model, + model_description: pathlib.Path | str | rod.Model | rod.Sdf, model_name: str | None = None, is_urdf: bool | None = None, ) -> SDFData: @@ -41,30 +41,37 @@ def extract_model_data( Extract data from an SDF/URDF resource useful to build a JaxSim model. Args: - model_description: A path to an SDF/URDF file, a string containing its content, - or a pre-parsed/pre-built rod model. + model_description: + A path to an SDF/URDF file, a string containing its content, or + a pre-parsed/pre-built rod model. model_name: The name of the model to extract from the SDF resource. - is_urdf: Whether the SDF resource is a URDF file. Needed only if model_description - is a URDF string. + is_urdf: + Whether to force parsing the resource as a URDF file. Automatically + detected if not provided. Returns: The extracted model data. """ - if isinstance(model_description, rod.Model): - sdf_model = model_description - else: - # Parse the SDF resource. - sdf_element = rod.Sdf.load(sdf=model_description, is_urdf=is_urdf) - - if len(sdf_element.models()) == 0: - raise RuntimeError("Failed to find any model in SDF resource") - - # Assume the SDF resource has only one model, or the desired model name is given. - sdf_models = {m.name: m for m in sdf_element.models()} - sdf_model = ( - sdf_element.models()[0] if len(sdf_models) == 1 else sdf_models[model_name] - ) + match model_description: + case rod.Model(): + sdf_model = model_description + case rod.Sdf() | str() | pathlib.Path(): + sdf_element = ( + model_description + if isinstance(model_description, rod.Sdf) + else rod.Sdf.load(sdf=model_description, is_urdf=is_urdf) + ) + if not sdf_element.models(): + raise RuntimeError("Failed to find any model in SDF resource") + + # Assume the SDF resource has only one model, or the desired model name is given. + sdf_models = {m.name: m for m in sdf_element.models()} + sdf_model = ( + sdf_element.models()[0] + if len(sdf_models) == 1 + else sdf_models[model_name] + ) # Log model name. logging.debug(msg=f"Found model '{sdf_model.name}' in SDF resource") @@ -344,7 +351,7 @@ def extract_model_data( def build_model_description( model_description: pathlib.Path | str | rod.Model, - is_urdf: bool | None = False, + is_urdf: bool | None = None, ) -> descriptions.ModelDescription: """ Builds a model description from an SDF/URDF resource. @@ -352,8 +359,9 @@ def build_model_description( Args: model_description: A path to an SDF/URDF file, a string containing its content, or a pre-parsed/pre-built rod model. - is_urdf: Whether the SDF resource is a URDF file. Needed only if model_description - is a URDF string. + is_urdf: Whether the force parsing the resource as a URDF file. Automatically + detected if not provided. + Returns: The parsed model description. """ diff --git a/tests/conftest.py b/tests/conftest.py index 10349f4c2..b3d5fb1a5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -83,24 +83,9 @@ def build_jaxsim_model( A JaxSim model built from the provided description. """ - is_urdf = None - - # If the provided description is a string, automatically detect if it - # contains the content of a URDF or SDF file. - if isinstance(model_description, str): - if "" in model_description: - is_urdf = False - - else: - is_urdf = None - # Build the JaxSim model. model = js.model.JaxSimModel.build_from_model_description( model_description=model_description, - is_urdf=is_urdf, ) return model diff --git a/tests/test_api_contact.py b/tests/test_api_contact.py index f055c18a6..6456f2645 100644 --- a/tests/test_api_contact.py +++ b/tests/test_api_contact.py @@ -85,7 +85,7 @@ def test_contact_jacobian_derivative( W_p_Ci = model.kin_dyn_parameters.contact_parameters.point # Load the model in ROD. - rod_model = rod.Sdf.load(sdf=model.built_from, is_urdf=True).model + rod_model = rod.Sdf.load(sdf=model.built_from).model # Add dummy frames on the contact points. for idx, (link_name, W_p_C) in enumerate( diff --git a/tests/test_pytree.py b/tests/test_pytree.py index 8d063dc87..561d5ac53 100644 --- a/tests/test_pytree.py +++ b/tests/test_pytree.py @@ -14,14 +14,12 @@ def test_call_jit_compiled_function_passing_different_objects( # Create a first model from the URDF. model1 = js.model.JaxSimModel.build_from_model_description( - model_description=ergocub_model_description_path, - is_urdf=True, + model_description=ergocub_model_description_path ) # Create a second model from the URDF. model2 = js.model.JaxSimModel.build_from_model_description( - model_description=ergocub_model_description_path, - is_urdf=True, + model_description=ergocub_model_description_path ) # The objects should be different, but the comparison should return True.