diff --git a/ivy/stateful/utilities.py b/ivy/stateful/utilities.py index b308188db60e..5106c151b4c4 100644 --- a/ivy/stateful/utilities.py +++ b/ivy/stateful/utilities.py @@ -569,19 +569,6 @@ def _compute_module_dict_tf(model, prefix=""): "to install it and restart your interpreter to see the changes." ) from exc - try: - assert isinstance( - model_pt, torch.nn.Module - ), "The original model must be an instance of `torch.nn.Module` (PyTorch)." - except AssertionError as e: - raise TypeError("PyTorch model is required as the first argument.") from e - - try: - assert isinstance( - model_tf, (tf.keras.Model,tf.keras.layers.Layer) - ), "The second model must be an instance of `tf.keras.Model` (TensorFlow)." - except AssertionError as e: - raise TypeError("The second model must be a TensorFlow model.") from e if hasattr(model_tf, "named_parameters"): _sync_models_torch_and_tf(model_pt, model_tf) @@ -714,20 +701,6 @@ def _compute_module_dict_jax(model, prefix=""): "to install it and restart your interpreter to see the changes." ) from exc - try: - assert isinstance( - model_pt, torch.nn.Module - ), "The original model must be an instance of `torch.nn.Module` (PyTorch)." - except AssertionError as e: - raise TypeError("PyTorch model is required as the first argument.") from e - - try: - assert isinstance( - model_jax, nnx.Module - ), "The second model must be an instance of `nnx.Module`." - except AssertionError as e: - raise TypeError("The second model must be a Flax model.") from e - if hasattr(model_jax, "named_parameters"): _sync_models_torch_and_jax(model_pt, model_jax)