Skip to content

Commit

Permalink
add basic test to generator and encoder functions
Browse files Browse the repository at this point in the history
  • Loading branch information
CosmoMatt committed Jun 14, 2024
1 parent 68d2033 commit 9300f08
Showing 1 changed file with 40 additions and 0 deletions.
40 changes: 40 additions & 0 deletions tests/test_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import numpy as np
import pytest
import s2scat

L_to_test = [8]
N_to_test = [3]
recursive_transform = [False, True]
isotropic = [False, True]


@pytest.mark.parametrize("L", L_to_test)
@pytest.mark.parametrize("N", N_to_test)
@pytest.mark.parametrize("recursive", recursive_transform)
@pytest.mark.parametrize("isotropic", isotropic)
def test_generator(L: int, N: int, recursive: bool, isotropic: bool):
xlm = jnp.array(
np.random.randn(L, L) + 1j * np.random.randn(L, L), dtype=jnp.complex64
)
generator = s2scat.build_generator(xlm, L, N, 0, True, recursive, isotropic)
key = jax.random.PRNGKey(0)
xlm_new = generator(key, count=10, niter=10)
assert xlm_new.shape == (10, L, 2 * L - 1)


@pytest.mark.parametrize("L", L_to_test)
@pytest.mark.parametrize("N", N_to_test)
@pytest.mark.parametrize("recursive", recursive_transform)
@pytest.mark.parametrize("isotropic", isotropic)
def test_encoder(L: int, N: int, recursive: bool, isotropic: bool):
xlm = jnp.array(
np.random.randn(10, L, L) + 1j * np.random.randn(10, L, L),
dtype=jnp.complex128,
)
encoder = s2scat.build_encoder(L, N, 0, True, recursive, isotropic)
latents = encoder(xlm)
assert len(latents) == 6

0 comments on commit 9300f08

Please sign in to comment.