diff --git a/tf2jax/_src/ops_test.py b/tf2jax/_src/ops_test.py index 2a6da77..aa4e467 100644 --- a/tf2jax/_src/ops_test.py +++ b/tf2jax/_src/ops_test.py @@ -339,8 +339,6 @@ def test_avg_pool(self, padding, data_format): strides = (1, 3, 2, 1) if data_format == "NCHW": - if jax.default_backend().lower() == "cpu": - self.skipTest("TensorFlow AvgPool does not support NCHW on CPU.") inputs = np.transpose(inputs, [0, 3, 1, 2]) ksize = _reorder(ksize, [0, 3, 1, 2]) strides = _reorder(strides, [0, 3, 1, 2]) @@ -620,8 +618,6 @@ def test_conv2d(self, input_shape, filter_shape, strides, padding, filters = np.random.normal(size=filter_shape).astype(np.float32) inputs = np.random.normal(size=input_shape).astype(np.float32) if data_format == "NCHW": - if jax.default_backend().lower() == "cpu": - self.skipTest("TensorFlow Conv2D does not support NCHW on CPU.") inputs = np.transpose(inputs, [0, 3, 1, 2]) strides = _reorder(strides, [0, 3, 1, 2]) dilations = _reorder(dilations, [0, 3, 1, 2]) @@ -664,9 +660,6 @@ def test_conv2d_transpose(self, data_format): filters = np.random.normal(size=[7, 7, 128, 13]).astype(np.float32) inputs = np.random.normal(size=[3, 4, 4, 13]).astype(np.float32) if data_format == "NCHW": - if jax.default_backend().lower() == "cpu": - self.skipTest( - "TensorFlow Conv2DBackpropInput does not support NCHW on CPU.") inputs = np.transpose(inputs, [0, 3, 1, 2]) strides = _reorder(strides, [0, 3, 1, 2]) dilations = _reorder(dilations, [0, 3, 1, 2]) @@ -768,9 +761,6 @@ def test_depthwise_conv2d(self, use_explicit_paddings, data_format): dilations = [1, 1, 1, 1] if data_format == "NCHW": - if jax.default_backend().lower() == "cpu": - self.skipTest( - "TensorFlow DepthwiseConv2dNative does not support NCHW on CPU.") inputs = np.transpose(inputs, [0, 3, 1, 2]) strides = _reorder(strides, [0, 3, 1, 2]) dilations = _reorder(dilations, [0, 3, 1, 2]) @@ -1081,13 +1071,6 @@ def test_igamma_ops(self): @chex.variants(with_jit=True, without_jit=True) def test_inplace_add(self): - if test_util.parse_version(tf.version.VERSION) >= test_util.parse_version( - "2.14.0" - ): - self.skipTest( - f"Requires earlier than tf 2.14.0, found {tf.version.VERSION}." - ) - np.random.seed(42) @tf.function @@ -1116,13 +1099,6 @@ def inplace_add(x, idx, val): @chex.variants(with_jit=True, without_jit=True) def test_inplace_update(self): - if test_util.parse_version(tf.version.VERSION) >= test_util.parse_version( - "2.14.0" - ): - self.skipTest( - f"Requires earlier than tf 2.14.0, found {tf.version.VERSION}." - ) - np.random.seed(42) @tf.function @@ -1278,8 +1254,6 @@ def test_max_pool(self, padding, data_format): strides = (1, 3, 2, 1) if data_format == "NCHW": - if jax.default_backend().lower() == "cpu": - self.skipTest("TensorFlow MaxPool does not support NCHW on CPU.") inputs = np.transpose(inputs, [0, 3, 1, 2]) ksize = _reorder(ksize, [0, 3, 1, 2]) strides = _reorder(strides, [0, 3, 1, 2]) @@ -2186,9 +2160,6 @@ def while_loop(x): return tf.while_loop(cond, body, [x, step]) self._test_convert(while_loop, inputs) - if jax.default_backend().lower() == "gpu": - self.skipTest("Skip remaining tests on GPU due to CUDA errors.") - def raw_stateless_while(x): loop_vars = [x, step] return tf.raw_ops.StatelessWhile( diff --git a/tf2jax/_src/roundtrip_test.py b/tf2jax/_src/roundtrip_test.py index 8738fff..5302710 100644 --- a/tf2jax/_src/roundtrip_test.py +++ b/tf2jax/_src/roundtrip_test.py @@ -1166,11 +1166,6 @@ def forward(a, b): grad_tols=tols) def test_explicit_native_serialization(self): - if test_util.parse_version(tf.version.VERSION) < test_util.parse_version( - "2.12.0" - ): - self.skipTest(f"Requires tf 2.12.0 or later, found {tf.version.VERSION}.") - def forward(x): return x + 3.14