Skip to content

Commit

Permalink
Merge pull request #18 from astro-informatics/housekeeping/naming_and…
Browse files Browse the repository at this point in the history
…_aliasing

updaing naming conventions and function aliasing
  • Loading branch information
CosmoMatt authored Jun 17, 2024
2 parents ee06896 + cc307e1 commit b049934
Show file tree
Hide file tree
Showing 42 changed files with 1,505 additions and 800 deletions.
2 changes: 1 addition & 1 deletion .pip_readme.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
:target: add_link_here


Scattering covariance transform on the sphere
Differentiable scattering covariances on the sphere
=================================================================================================================

``S2SCAT`` is a Python package for computing third generation scattering covariances on the
Expand Down
62 changes: 36 additions & 26 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

# Differentiable scattering covariances on the sphere

`S2SCAT` is a Python package for computing scattering covariances on the sphere ([Mousset et al. 2024](https://arxiv.org/abs/xxxx.xxxxx)) using JAX. It exploits autodiff to provide differentiable transforms, which are also deployable on hardware accelerators (e.g. GPUs and TPUs), leveraging the differentiable and accelerated spherical harmonic and wavelet transforms implemented in [S2FFT](https://github.com/astro-informatics/s2fft) and [S2WAV](https://github.com/astro-informatics/s2wav), respectively.
`S2SCAT` is a Python package for computing scattering covariances on the sphere ([Mousset et al. 2024](https://arxiv.org/abs/xxxx.xxxxx)) using JAX. It exploits autodiff to provide differentiable transforms, which are also deployable on hardware accelerators (e.g. GPUs and TPUs), leveraging the differentiable and accelerated spherical harmonic and wavelet transforms implemented in [S2FFT](https://github.com/astro-informatics/s2fft) and [S2WAV](https://github.com/astro-informatics/s2wav), respectively. Scattering covariances are useful both for field-level generative modelling of complex non-Gaussian textures and for statistical compression of high dimensional field-level data, a key step of e.g. simulation based inference.

> [!IMPORTANT]
> It is worth highlighting that the input to `S2SCAT` are spherical harmonic coefficients, which can be generated with whichever software package you prefer, e.g. [`S2FFT`](https://github.com/astro-informatics/s2fft) or [`healpy`](https://healpy.readthedocs.io/en/latest/). Just ensure your harmonic coefficients are indexed using our convention; helper functions for this reindexing can be found in [`S2FFT`](https://github.com/astro-informatics/s2fft).
Expand All @@ -23,6 +23,8 @@ Ballpark compute times (when running on an 40GB A100 GPU) and compression levels
| Precompute | L=512, N=3 | ~90ms | ~190ms | ~20s | 2,618,880 | ~ 63,000 (97.594%) | ~504 (99.981%) |
| On-the-fly | L=2048, N=3 | ~18s | ~40s | ~5m | 41,932,800 | ~ 123,750 (99.705%) | ~ 990 (99.998%) |

Note that these times are not batched, so in practice may be substantially faster.

## Scattering covariances :dna:

<p align="center">
Expand All @@ -43,16 +45,36 @@ $$S_4^{\lambda_1, \lambda_2, \lambda_3} = \text{Cov} \left[W^{\lambda_1}|W^{\lam

where $W^{\lambda} I$ denotes the wavelet transform of field $I$ at scale $j$ and direction $\gamma$, which we group into a single label $\lambda=(j,\gamma)$.

This statistical representation characterises the power and sparsity at given scales, as well as covariant features between different wavelet scale and directions, which can effectively capture complex non-Gaussian structural information, e.g. filamentary structure.
This statistical representation characterises the power and sparsity at given scales, as well as covariant features between different wavelet scale and directions, which can effectively capture complex non-Gaussian structural information, e.g. filamentary structure.

Using the recently released JAX spherical harmonic code [`S2FFT`](https://github.com/astro-informatics/s2fft) ([Price & McEwen 2024](https://arxiv.org/abs/2311.14670)) and spherical wavelet transform code [`S2WAV`](https://github.com/astro-informatics/s2wav) ([Price et al. 2024](<https://arxiv.org/abs/2402.01282)) in the `S2SCAT` code we extends scattering covariances to the sphere, which are necessary for their application to generative modelling of wide-field cosmological fields ([Mousset et al. 2024](https://arxiv.org/abs/xxxx.xxxxx)).

## Usage :rocket:

To import and use `S2SCAT` is as simple follows:

``` python
import s2scat, jax
# For statistical compression
encoder = s2scat.build_encoder(L, N) # Returns a callable compression model.
covariance_statistics = encoder(alm) # Generate statistics (can be batched).

# For generative modelling
key = jax.random.PRNGKey(seed)
generator = s2scat.build_generator(alm, L, N) # Returns a callable generative model.
new_samples = generator(key, 10) # Generate 10 new spherical textures.
```

For further details on usage see the [documentation](https://astro-informatics.github.io/s2scat/) and associated [notebooks](https://astro-informatics.github.io/s2scat/notebooks/).

## Package Directory Structure :art:

``` bash
s2scat/
├── core/ # Top-level functionality:
│ ├─ scatter.py # - Scattering covariance transform.
│ ├─ compress.py # - Statistical compression functions.
│ ├─ synthesis.py # - Synthesis optimisation functions.
├── representation.py # - Scattering covariance transform.
├── compression.py # - Statistical compression functions.
├── optimisation.py # - Optimisation algorithm wrappers.
├── generation.py # - Latent encoder and Generative decoder.
├── operators/ # Internal functionality:
│ ├─ spherical.py # - Specific spherical operations, e.g. batched SHTs.
Expand Down Expand Up @@ -94,21 +116,6 @@ pytest tests/

Documentation for the released version is available [here](https://astro-informatics.github.io/s2scat/).

## Usage :rocket:

To import and use `S2SCAT` is as simple follows:

``` python
import s2scat

# Given harmonic bandlimit L, azimuthal bandlimit N and spherical harmonic coefficients flm

config = s2scat.configure(L, N)
covariances = s2scat.scatter(flm, L, N, config=config)
```

For further details on usage see the [documentation](https://astro-informatics.github.io/s2scat/) and associated [notebooks](https://astro-informatics.github.io/s2scat/notebooks/).

## Contributors

<!-- ALL-CONTRIBUTORS-LIST:START - Do not remove or modify this section -->
Expand Down Expand Up @@ -151,11 +158,14 @@ code builds:

```
@article{price:s2fft,
author = "Matthew A. Price and Jason D. McEwen",
title = "Differentiable and accelerated spherical harmonic and Wigner transforms",
journal = "Journal of Computational Physics, submitted",
year = "2023",
eprint = "arXiv:2311.14670"
author = "Matthew A. Price and Jason D. McEwen",
title = "Differentiable and accelerated spherical harmonic and {W}igner transforms",
journal = "Journal of Computational Physics",
volume = "510",
pages = "113109",
year = "2024",
doi = {10.1016/j.jcp.2024.113109},
eprint = "arXiv:2311.14670"
}
```
```
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@
**************************
Compression
**************************
.. automodule:: s2scat.core.compress
.. automodule:: s2scat.compression
:members:
4 changes: 2 additions & 2 deletions docs/api/core/scatter.rst → docs/api/core/generation.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
:html_theme.sidebar_secondary.remove:

**************************
Scattering
Generation
**************************
.. automodule:: s2scat.core.scatter
.. automodule:: s2scat.generation
:members:
41 changes: 26 additions & 15 deletions docs/api/core/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,49 +4,60 @@
Core Functions
**************************

.. list-table:: Scattering Operations
.. list-table:: Generative models
:widths: 25 25
:header-rows: 1

* - Function Name
- Description
* - :func:`~s2scat.core.scatter.directional`
* - :func:`~s2scat.generation.build_encoder`
- Builds a scattering covariance encoding function (latent encoder).
* - :func:`~s2scat.generation.build_generator`
- Builds a scattering covariance generator function. (generative model).

.. list-table:: Scattering transforms
:widths: 25 25
:header-rows: 1

* - Function Name
- Description
* - :func:`~s2scat.representation.scatter`
- Compute directional scattering covariances on the sphere (Mousset et al 2024).
* - :func:`~s2scat.core.scatter.directional_c`
* - :func:`~s2scat.representation.scatter_c`
- Compute directional scattering covariances on the sphere using a custom C backend (Mousset et al 2024).

.. list-table:: Compression Operations
.. list-table:: Compression functions
:widths: 25 25
:header-rows: 1

* - Function Name
- Description
* - :func:`~s2scat.core.compress.C01_C11_to_isotropic`
* - :func:`~s2scat.compression.C01_C11_to_isotropic`
- Convert scattering covariances to their isotropic counterpart.

.. list-table:: Synthesis Operations
.. list-table:: Optimisation functions
:widths: 25 25
:header-rows: 1

* - Function Name
- Description
* - :func:`~s2scat.core.synthesis.fit_jaxopt_scipy`
- Minimises the declared loss function starting at params using jaxopt.
* - :func:`~s2scat.core.synthesis.fit_optax`
* - :func:`~s2scat.optimisation.fit_optax`
- Minimises the declared loss function starting at params using optax (adam).
* - :func:`~s2scat.core.synthesis.l2_covariance_loss`
* - :func:`~s2scat.optimisation.l2_covariance_loss`
- L2 loss wrapper for the scattering covariance.
* - :func:`~s2scat.core.synthesis.l2_loss`
* - :func:`~s2scat.optimisation.l2_loss`
- L2 loss for a single scattering covariance.
* - :func:`~s2scat.core.synthesis.get_P00prime`
* - :func:`~s2scat.optimisation.get_P00prime`
- Computes P00prime which is the averaged power within each wavelet scale.

.. toctree::
:hidden:
:maxdepth: 2
:caption: Core Functions

scatter
compress
synthesis
generation
representation
compression
optimisation


Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
:html_theme.sidebar_secondary.remove:

**************************
Synthesis
Optimisation
**************************
.. automodule:: s2scat.core.synthesis
.. automodule:: s2scat.optimisation
:members:
7 changes: 7 additions & 0 deletions docs/api/core/representation.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
:html_theme.sidebar_secondary.remove:

**************************
Representation
**************************
.. automodule:: s2scat.representation
:members:
8 changes: 4 additions & 4 deletions docs/api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ Directory structure
.. code-block:: bash
s2scat/
├── core/ # Top-level functionality:
│ ├─ scatter.py # - Scattering covariance transform.
│ ├─ compress.py # - Statistical compression functions.
│ ├─ synthesis.py # - Synthesis optimisation functions.
├── representation.py # - Scattering covariance transform.
├── compression.py # - Statistical compression functions.
├── optimisation.py # - Optimisation algorithm wrappers.
├── generation.py # - Latent encoder and Generative decoder.
├── operators/ # Internal functionality:
│ ├─ spherical.py # - Specific spherical operations, e.g. batched SHTs.
Expand Down
4 changes: 0 additions & 4 deletions docs/api/utility/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ Helper Functions
- Computes the mean and variance of spherical harmonic coefficients :math:`f_{\ell m}`.
* - :func:`~s2scat.utility.statistics.normalize_map`
- Normalises a spherical map to zero mean and unit variance.
* - :func:`~s2scat.utility.statistics.compute_P00`
- Stand alone function to compute the second order power statistics.
* - :func:`~s2scat.utility.statistics.compute_C01_and_C11`
- Stand alone function to compute the fourth and sixth order covariance statistics.
* - :func:`~s2scat.utility.statistics.add_to_S1`
Expand All @@ -35,8 +33,6 @@ Helper Functions
- Computes and appends the fourth order covariance statistic :math:`\text{C01}_j = \text{Cov}\big [ \Psi^{\lambda_1} f, \Psi^{\lambda_1} | \Psi^{\lambda_2} f | \big ]` at scale :math:`j`.
* - :func:`~s2scat.utility.statistics.add_to_C11`
- Computes and appends the sixth order covariance statistic :math:`\text{C11}_j = \text{Cov}\big [ \Psi^{\lambda_1} | \Psi^{\lambda_3} f |, \Psi^{\lambda_1} | \Psi^{\lambda_2} f | \big ]` at scale :math:`j`.
* - :func:`~s2scat.utility.statistics.apply_norm`
- Applies normalisation to a complete list of covariance statistics.

.. list-table:: Statistical Normalisation Functions
:widths: 25 25
Expand Down
Loading

0 comments on commit b049934

Please sign in to comment.