Skip to content

Update README with CUDA extension install details and tweak to Pytorch support note #302

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 29, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 72 additions & 43 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
</div>

`S2FFT` is a Python package for computing Fourier transforms on the sphere
and rotation group [(Price & McEwen 2024)](https://arxiv.org/abs/2311.14670) using
JAX or PyTorch. It leverages autodiff to provide differentiable transforms, which are
and rotation group [(Price & McEwen 2024)](https://arxiv.org/abs/2311.14670) using
JAX or PyTorch. It leverages autodiff to provide differentiable transforms, which are
also deployable on hardware accelerators (e.g. GPUs and TPUs).

More specifically, `S2FFT` provides support for spin spherical harmonic
Expand All @@ -41,16 +41,16 @@ angular resolution $L$. The diagram below illustrates the recursions
<img class="dark-light" alt="Schematic of Wigner recursions" src="https://raw.githubusercontent.com/astro-informatics/s2fft/main/docs/assets/figures/Wigner_recursion_legend_darkmode.png" />
</div>

With this recursion to hand, the spherical harmonic coefficients of an
isolatitudinally sampled map may be computed as a two step process. First,
a 1D Fourier transform over longitude, for each latitudinal ring. Second,
a projection onto the real polar-d functions. One may precompute and store
all real polar-d functions for extreme acceleration, however this comes
with an equally extreme memory overhead, which is infeasible at $L \sim 1024$.
Alternatively, the real polar-d functions may calculated recursively,
computing only a portion of the projection at a time, hence incurring
negligible memory overhead at the cost of slightly slower execution. The
diagram below illustrates the separable spherical harmonic transform
With this recursion to hand, the spherical harmonic coefficients of an
isolatitudinally sampled map may be computed as a two step process. First,
a 1D Fourier transform over longitude, for each latitudinal ring. Second,
a projection onto the real polar-d functions. One may precompute and store
all real polar-d functions for extreme acceleration, however this comes
with an equally extreme memory overhead, which is infeasible at $L \sim 1024$.
Alternatively, the real polar-d functions may calculated recursively,
computing only a portion of the projection at a time, hence incurring
negligible memory overhead at the cost of slightly slower execution. The
diagram below illustrates the separable spherical harmonic transform
(for further details see [Price & McEwen 2024]((https://arxiv.org/abs/2311.14670))).

<div style="text-align: center;" align="center">
Expand All @@ -65,7 +65,7 @@ supported.

The equiangular sampling schemes of [McEwen & Wiaux
(2012)](https://arxiv.org/abs/1110.6298), [Driscoll & Healy
(1995)](https://www.sciencedirect.com/science/article/pii/S0196885884710086)
(1995)](https://www.sciencedirect.com/science/article/pii/S0196885884710086)
and [Gauss-Legendre (1986)](https://link.springer.com/article/10.1007/BF02519350)
are supported, which exhibit associated sampling theorems and so
harmonic transforms can be computed to machine precision. Note that the
Expand All @@ -84,8 +84,17 @@ pixels of equal areas, which has many practical advantages.
<img class="dark-light" alt="Visualization of spherical sampling schemes" src="https://raw.githubusercontent.com/astro-informatics/s2fft/main/docs/assets/figures/spherical_sampling.png" width="700">
</div>

> [!NOTE]
> For algorithmic reasons JIT compilation of HEALPix transforms can become slow at high bandlimits, due to XLA unfolding of loops which currently cannot be avoided. After compiling HEALPix transforms should execute with the efficiency outlined in the associated paper, therefore this additional time overhead need only be incurred once. We are aware of this issue and are working to fix it. A fix for CPU execution has now been implemented (see example [notebook](https://astro-informatics.github.io/s2fft/tutorials/spherical_harmonic/JAX_HEALPix_backend.html)).
> [!NOTE]
> For algorithmic reasons JIT compilation of HEALPix transforms can become slow at high bandlimits, due to XLA unfolding of loops.
> After compilation, HEALPix transforms should execute with the efficiency outlined in the associated paper, therefore this additional time overhead need only be incurred once.
>
> __If running on a CPU__, we provide (differentiable) JAX wrappers of the [`healpy`](https://healpy.readthedocs.io/en/latest/) transforms which can be used to sidestep the issue.
> This implementation can be selected by passing a `method="jax_healpy"` keyword argument to the `s2fft.forward` or `s2fft.inverse` functions -
> see example [notebook](https://astro-informatics.github.io/s2fft/tutorials/spherical_harmonic/JAX_HEALPix_backend.html).
>
> __If running on a GPU__, a CUDA extension module is available which avoids the long compilation time.
> This implementation can be selected by passing a `method="jax_cuda"` keyword argument to the `sfft.forward` and `s2fft.inverse` functions.
> Currently we do not publish binary wheels with the CUDA extension support so you will need to [build the package from source](#cuda-extension-support) to use this functionality.

## Installation 💻

Expand All @@ -101,16 +110,36 @@ if you wish to install JAX with GPU or TPU support,
you should first follow the [relevant installation instructions in JAX's documentation](https://docs.jax.dev/en/latest/installation.html#installation)
and then install `S2FFT` as above.

Alternatively, the latest development version of `S2FFT` may be installed directly from GitHub by running
Alternatively, the latest development version of `S2FFT` may be installed directly from GitHub by running

```bash
pip install git+https://github.com/astro-informatics/s2fft
pip install git+https://github.com/astro-informatics/s2fft
```

### CUDA extension support

To install the package with support for the CUDA extension module giving reduced compile times for running HEALPix transforms on the GPU,
you will need to build from source on a system with CUDA (tested with version 12.3) and CMake (versions 3.19+) installed.

To install the latest development version from source in verbose mode run

```bash
pip install -v git+https://github.com/astro-informatics/s2fft
```

or to install a specific release tag in verbose mode run

```bash
pip install -v git+https://github.com/astro-informatics/s2fft@TAG
```

where `TAG` is the relevant version tag.
The output should indicate if the CUDA install on your system is successfully detected.

## Tests 🚦

A `pytest` test suite for the package is included in the `tests` directory.
To install the test dependencies, clone the repository and install the package (in [editable mode](https://setuptools.pypa.io/en/latest/userguide/development_mode.html))
To install the test dependencies, clone the repository and install the package (in [editable mode](https://setuptools.pypa.io/en/latest/userguide/development_mode.html))
with the extra test dependencies by running from the root of the repository

```bash
Expand All @@ -120,13 +149,13 @@ pip install -e ".[tests]"
To run the tests, run from the root of the repository

```bash
pytest
pytest
```

## Documentation 📖

Documentation for the released version is available [here](https://astro-informatics.github.io/s2fft/).
To install the documentation dependencies, clone the repository and install the package (in [editable mode](https://setuptools.pypa.io/en/latest/userguide/development_mode.html))
Documentation for the released version is available [here](https://astro-informatics.github.io/s2fft/).
To install the documentation dependencies, clone the repository and install the package (in [editable mode](https://setuptools.pypa.io/en/latest/userguide/development_mode.html))
with the extra documentation dependencies by running from the root of the repository

```bash
Expand All @@ -136,7 +165,7 @@ pip install -e ".[docs]"
To build the documentation, run from the root of the repository

```bash
cd docs
cd docs
make html
open _build/html/index.html
```
Expand All @@ -146,7 +175,7 @@ open _build/html/index.html
A series of tutorial notebooks are included in the `notebooks` directory
and rendered [in the documentation](https://astro-informatics.github.io/s2fft/tutorials/index.html).

To install the dependencies required to run the notebooks locally, clone the repository and install the package (in [editable mode](https://setuptools.pypa.io/en/latest/userguide/development_mode.html))
To install the dependencies required to run the notebooks locally, clone the repository and install the package (in [editable mode](https://setuptools.pypa.io/en/latest/userguide/development_mode.html))
with the extra documentation and plotting dependencies by running from the root of the repository

```bash
Expand All @@ -172,12 +201,12 @@ import s2fft
f = ...
L = ...
# Compute harmonic coefficients
flm = s2fft.forward(f, L, method="jax")
flm = s2fft.forward(f, L, method="jax")
# Map back to pixel-space signal
f = s2fft.inverse(flm, L, method="jax")
```

For a signal on the rotation group
For a signal on the rotation group

```python
import s2fft
Expand All @@ -194,27 +223,27 @@ f = fft.wigner.inverse_jax(flmn, L, N, method="jax")

For further details on usage see the [documentation](https://astro-informatics.github.io/s2fft/) and associated [notebooks](https://astro-informatics.github.io/s2fft/tutorials/spherical_harmonic/spherical_harmonic_transform.html).

> [!NOTE]
> We also provide PyTorch support for our transforms, as demonstrated in the [_Torch frontend_ tutorial notebook](https://astro-informatics.github.io/s2fft/tutorials/torch_frontend/torch_frontend.html). This wraps the JAX implementations so JAX will need to be installed in addition to PyTorch.
We also provide PyTorch support for our transforms, as demonstrated in the [_Torch frontend_ tutorial notebook](https://astro-informatics.github.io/s2fft/tutorials/torch_frontend/torch_frontend.html).
This wraps the JAX implementations so JAX will need to be installed in addition to PyTorch.

## SSHT & HEALPix wrappers 💡

`S2FFT` also provides JAX support for existing C/C++ packages, specifically [`HEALPix`](https://healpix.jpl.nasa.gov) and [`SSHT`](https://github.com/astro-informatics/ssht). This works
`S2FFT` also provides JAX support for existing C/C++ packages, specifically [`HEALPix`](https://healpix.jpl.nasa.gov) and [`SSHT`](https://github.com/astro-informatics/ssht). This works
by wrapping Python bindings with custom JAX frontends. Note that this C/C++ to JAX interoperability is currently limited to CPU.

For example, one may call these alternate backends for the spherical harmonic transform by:

``` python
# Forward SSHT spherical harmonic transform
flm = s2fft.forward(f, L, sampling="mw", method="jax_ssht")
flm = s2fft.forward(f, L, sampling="mw", method="jax_ssht")

# Forward HEALPix spherical harmonic transform
flm = s2fft.forward(f, L, nside=nside, sampling="healpix", method="jax_healpy")
flm = s2fft.forward(f, L, nside=nside, sampling="healpix", method="jax_healpy")
```

All of these JAX frontends supports out of the box reverse mode automatic differentiation,
and under the hood is simply linking to the C/C++ packages you are familiar with. In this
way `S2fft` enhances existing packages with gradient functionality for modern scientific computing or machine learning
All of these JAX frontends supports out of the box reverse mode automatic differentiation,
and under the hood is simply linking to the C/C++ packages you are familiar with. In this
way `S2fft` enhances existing packages with gradient functionality for modern scientific computing or machine learning
applications!

For further details on usage see the associated [notebooks](https://astro-informatics.github.io/s2fft/tutorials/spherical_harmonic/JAX_SSHT_backend.html).
Expand Down Expand Up @@ -264,8 +293,8 @@ patterns!
Should this code be used in any way, we kindly request that the following article is
referenced. A BibTeX entry for this reference may look like:

```
@article{price:s2fft,
```
@article{price:s2fft,
author = "Matthew A. Price and Jason D. McEwen",
title = "Differentiable and accelerated spherical harmonic and Wigner transforms",
journal = "Journal of Computational Physics",
Expand All @@ -280,21 +309,21 @@ referenced. A BibTeX entry for this reference may look like:
You might also like to consider citing our related papers on which this
code builds:

```
```
@article{mcewen:fssht,
author = "Jason D. McEwen and Yves Wiaux",
title = "A novel sampling theorem on the sphere",
journal = "IEEE Trans. Sig. Proc.",
year = "2011",
volume = "59",
number = "12",
pages = "5876--5887",
pages = "5876--5887",
eprint = "arXiv:1110.6298",
doi = "10.1109/TSP.2011.2166394"
}
```

```
```
@article{mcewen:so3,
author = "Jason D. McEwen and Martin B{\"u}ttner and Boris ~Leistedt and Hiranya V. Peiris and Yves Wiaux",
title = "A novel sampling theorem on the rotation group",
Expand All @@ -304,7 +333,7 @@ code builds:
number = "12",
pages = "2425--2429",
eprint = "arXiv:1508.03101",
doi = "10.1109/LSP.2015.2490676"
doi = "10.1109/LSP.2015.2490676"
}
```

Expand All @@ -319,10 +348,10 @@ Copyright 2023 Matthew Price, Jason McEwen and contributors.
details see the [`LICENCE.txt`](https://github.com/astro-informatics/s2fft/blob/main/LICENCE.txt) file.

The file [`lib/include/kernel_helpers.h`](https://github.com/astro-informatics/s2fft/blob/main/lib/include/kernel_helpers.h) is adapted from
[code](https://github.com/dfm/extending-jax/blob/c33869665236877a2ae281f3f5dbff579e8f5b00/lib/kernel_helpers.h) in [a tutorial on extending JAX](https://github.com/dfm/extending-jax) by
[code](https://github.com/dfm/extending-jax/blob/c33869665236877a2ae281f3f5dbff579e8f5b00/lib/kernel_helpers.h) in [a tutorial on extending JAX](https://github.com/dfm/extending-jax) by
[Dan Foreman-Mackey](https://github.com/dfm) and licensed under a [MIT license](https://github.com/dfm/extending-jax/blob/371dca93c6405368fa8e71690afd3968d75f4bac/LICENSE).

The file [`lib/include/kernel_nanobind_helpers.h`](https://github.com/astro-informatics/s2fft/blob/main/lib/include/kernel_nanobind_helpers.h)
is adapted from [code](https://github.com/jax-ml/jax/blob/3d389a7fb440c412d95a1f70ffb91d58408247d0/jaxlib/kernel_nanobind_helpers.h)
by the [JAX](https://github.com/jax-ml/jax) authors
and licensed under a [Apache-2.0 license](https://github.com/jax-ml/jax/blob/3d389a7fb440c412d95a1f70ffb91d58408247d0/LICENSE).
is adapted from [code](https://github.com/jax-ml/jax/blob/3d389a7fb440c412d95a1f70ffb91d58408247d0/jaxlib/kernel_nanobind_helpers.h)
by the [JAX](https://github.com/jax-ml/jax) authors
and licensed under a [Apache-2.0 license](https://github.com/jax-ml/jax/blob/3d389a7fb440c412d95a1f70ffb91d58408247d0/LICENSE).
Loading