Skip to content

Commit

Permalink
Remove obsolete skipTests.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 647587855
  • Loading branch information
shaobohou authored and TF2JAXDev committed Jun 28, 2024
1 parent 0a5da95 commit 4f1ee99
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 34 deletions.
29 changes: 0 additions & 29 deletions tf2jax/_src/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,6 @@ def test_avg_pool(self, padding, data_format):
strides = (1, 3, 2, 1)

if data_format == "NCHW":
if jax.default_backend().lower() == "cpu":
self.skipTest("TensorFlow AvgPool does not support NCHW on CPU.")
inputs = np.transpose(inputs, [0, 3, 1, 2])
ksize = _reorder(ksize, [0, 3, 1, 2])
strides = _reorder(strides, [0, 3, 1, 2])
Expand Down Expand Up @@ -620,8 +618,6 @@ def test_conv2d(self, input_shape, filter_shape, strides, padding,
filters = np.random.normal(size=filter_shape).astype(np.float32)
inputs = np.random.normal(size=input_shape).astype(np.float32)
if data_format == "NCHW":
if jax.default_backend().lower() == "cpu":
self.skipTest("TensorFlow Conv2D does not support NCHW on CPU.")
inputs = np.transpose(inputs, [0, 3, 1, 2])
strides = _reorder(strides, [0, 3, 1, 2])
dilations = _reorder(dilations, [0, 3, 1, 2])
Expand Down Expand Up @@ -664,9 +660,6 @@ def test_conv2d_transpose(self, data_format):
filters = np.random.normal(size=[7, 7, 128, 13]).astype(np.float32)
inputs = np.random.normal(size=[3, 4, 4, 13]).astype(np.float32)
if data_format == "NCHW":
if jax.default_backend().lower() == "cpu":
self.skipTest(
"TensorFlow Conv2DBackpropInput does not support NCHW on CPU.")
inputs = np.transpose(inputs, [0, 3, 1, 2])
strides = _reorder(strides, [0, 3, 1, 2])
dilations = _reorder(dilations, [0, 3, 1, 2])
Expand Down Expand Up @@ -768,9 +761,6 @@ def test_depthwise_conv2d(self, use_explicit_paddings, data_format):
dilations = [1, 1, 1, 1]

if data_format == "NCHW":
if jax.default_backend().lower() == "cpu":
self.skipTest(
"TensorFlow DepthwiseConv2dNative does not support NCHW on CPU.")
inputs = np.transpose(inputs, [0, 3, 1, 2])
strides = _reorder(strides, [0, 3, 1, 2])
dilations = _reorder(dilations, [0, 3, 1, 2])
Expand Down Expand Up @@ -1081,13 +1071,6 @@ def test_igamma_ops(self):

@chex.variants(with_jit=True, without_jit=True)
def test_inplace_add(self):
if test_util.parse_version(tf.version.VERSION) >= test_util.parse_version(
"2.14.0"
):
self.skipTest(
f"Requires earlier than tf 2.14.0, found {tf.version.VERSION}."
)

np.random.seed(42)

@tf.function
Expand Down Expand Up @@ -1116,13 +1099,6 @@ def inplace_add(x, idx, val):

@chex.variants(with_jit=True, without_jit=True)
def test_inplace_update(self):
if test_util.parse_version(tf.version.VERSION) >= test_util.parse_version(
"2.14.0"
):
self.skipTest(
f"Requires earlier than tf 2.14.0, found {tf.version.VERSION}."
)

np.random.seed(42)

@tf.function
Expand Down Expand Up @@ -1278,8 +1254,6 @@ def test_max_pool(self, padding, data_format):
strides = (1, 3, 2, 1)

if data_format == "NCHW":
if jax.default_backend().lower() == "cpu":
self.skipTest("TensorFlow MaxPool does not support NCHW on CPU.")
inputs = np.transpose(inputs, [0, 3, 1, 2])
ksize = _reorder(ksize, [0, 3, 1, 2])
strides = _reorder(strides, [0, 3, 1, 2])
Expand Down Expand Up @@ -2186,9 +2160,6 @@ def while_loop(x):
return tf.while_loop(cond, body, [x, step])
self._test_convert(while_loop, inputs)

if jax.default_backend().lower() == "gpu":
self.skipTest("Skip remaining tests on GPU due to CUDA errors.")

def raw_stateless_while(x):
loop_vars = [x, step]
return tf.raw_ops.StatelessWhile(
Expand Down
5 changes: 0 additions & 5 deletions tf2jax/_src/roundtrip_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1166,11 +1166,6 @@ def forward(a, b):
grad_tols=tols)

def test_explicit_native_serialization(self):
if test_util.parse_version(tf.version.VERSION) < test_util.parse_version(
"2.12.0"
):
self.skipTest(f"Requires tf 2.12.0 or later, found {tf.version.VERSION}.")

def forward(x):
return x + 3.14

Expand Down

0 comments on commit 4f1ee99

Please sign in to comment.