Skip to content

Commit 76fa862

Browse files
authored
Merge pull request #195 from astro-informatics/feature/JAX_frontend_for_C++_codes
add jax frontend support for c/c++ sht libraries
2 parents b3d033c + baa412b commit 76fa862

32 files changed

+1576
-335
lines changed

.github/workflows/docs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ jobs:
2929
pip install jaxlib
3030
pip install -r requirements/requirements-core.txt
3131
pip install -r requirements/requirements-docs.txt
32-
pip install .\[torch\]
32+
pip install .
3333
3434
- name: Build Documentation
3535
run: |

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ jobs:
3030
python -m pip install --upgrade pip
3131
pip install -r requirements/requirements-tests.txt
3232
pip install -r requirements/requirements-core.txt
33-
pip install .\[torch\]
33+
pip install .
3434
3535
- name: Run tests
3636
run: |

.pip_readme.rst

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
.. image:: https://img.shields.io/badge/code%20style-black-000000.svg
1212
:target: https://github.com/psf/black
1313
.. image:: https://colab.research.google.com/assets/colab-badge.svg
14-
:target: https://colab.research.google.com/drive/1YmJ2ljsF8HBvhPmD4hrYPlyAKc4WPUgq?usp=sharing
14+
:target: https://colab.research.google.com/github/astro-informatics/s2fft/blob/main/notebooksspherical_harmonic_transform.ipynb
1515

1616
Differentiable and accelerated spherical transforms
1717
=================================================================================================================
@@ -31,6 +31,12 @@ As of version 1.0.2 `S2FFT` also provides PyTorch implementations of underlying
3131
precompute transforms. In future releases this support will be extended to our
3232
on-the-fly algorithms.
3333

34+
As of version 1.1.0 `S2FFT` also provides JAX support for existing C/C++ packages,
35+
specifically `HEALPix` and `SSHT`. This works by wrapping python bindings with custom
36+
JAX frontends. Note that currently this C/C++ to JAX interoperability is currently
37+
limited to CPU, however for many applications this is desirable due to memory
38+
constraints.
39+
3440
Documentation
3541
=============
3642
Read the full documentation `here <https://astro-informatics.github.io/s2fft/>`_.

README.md

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
[![image](https://badge.fury.io/py/s2fft.svg)](https://badge.fury.io/py/s2fft)
55
[![image](http://img.shields.io/badge/arXiv-2311.14670-orange.svg?style=flat)](https://arxiv.org/abs/2311.14670)<!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
66
[![All Contributors](https://img.shields.io/badge/all_contributors-9-orange.svg?style=flat-square)](#contributors-)<!-- ALL-CONTRIBUTORS-BADGE:END -->
7-
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1YmJ2ljsF8HBvhPmD4hrYPlyAKc4WPUgq?usp=sharing)
7+
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/astro-informatics/s2fft/blob/main/notebooks/spherical_harmonic_transform.ipynb)
88
<!-- [![image](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) -->
99

1010
<img align="left" height="85" width="98" src="./docs/assets/sax_logo.png">
@@ -22,10 +22,20 @@ for adjoint transformations where needed, and comes with different
2222
optimisations (precompute or not) that one may select depending on
2323
available resources and desired angular resolution $L$.
2424

25+
> [!IMPORTANT]
26+
> HEALPix long JIT compile time fixed for CPU! Fix for GPU coming soon.
27+
28+
> [!TIP]
2529
As of version 1.0.2 `S2FFT` also provides PyTorch implementations of underlying
2630
precompute transforms. In future releases this support will be extended to our
2731
on-the-fly algorithms.
2832

33+
> [!TIP]
34+
As of version 1.1.0 `S2FFT` also provides JAX support for existing C/C++ packages,
35+
specifically `HEALPix` and `SSHT`. This works by wrapping python bindings with custom
36+
JAX frontends. Note that currently this C/C++ to JAX interoperability is currently
37+
limited to CPU.
38+
2939
## Algorithms :zap:
3040

3141
`S2FFT` leverages new algorithmic structures that can he highly
@@ -53,7 +63,7 @@ diagram below illustrates the separable spherical harmonic transform
5363
## Sampling :earth_africa:
5464

5565
The structure of the algorithms implemented in `S2FFT` can support any
56-
isolattitude sampling scheme. A number of sampling schemes are currently
66+
isolatitude sampling scheme. A number of sampling schemes are currently
5767
supported.
5868

5969
The equiangular sampling schemes of [McEwen & Wiaux
@@ -73,10 +83,10 @@ so the corresponding harmonic transforms do not achieve machine
7383
precision but exhibit some error. However, the HEALPix sampling provides
7484
pixels of equal areas, which has many practical advantages.
7585

76-
<p align="center"><img src="./docs/assets/figures/spherical_sampling.png" width="500"></p>
86+
<p align="center"><img src="./docs/assets/figures/spherical_sampling.png" width="700"></p>
7787

7888
> [!NOTE]
79-
> For algorithmic reasons JIT compilation of HEALPix transforms can become slow at high bandlimits, due to XLA unfolding of loops which currently cannot be avoided. After compiling HEALPix transforms should execute with the efficiency outlined in the associated paper, therefore this additional time overhead need only be incurred once. We are aware of this issue and will work to improve this in subsequent versions.
89+
> For algorithmic reasons JIT compilation of HEALPix transforms can become slow at high bandlimits, due to XLA unfolding of loops which currently cannot be avoided. After compiling HEALPix transforms should execute with the efficiency outlined in the associated paper, therefore this additional time overhead need only be incurred once. We are aware of this issue and are working to fix it. A fix for CPU execution has now been implemented (see example [notebook](https://astro-informatics.github.io/s2fft/tutorials/spherical_harmonic/JAX_HEALPix_backend.html)). Fix for GPU execution is coming soon.
8090
8191
## Installation :computer:
8292

@@ -87,12 +97,7 @@ into the active python environment by [pip](https://pypi.org) when running
8797
``` bash
8898
pip install s2fft
8999
```
90-
This will install all core functionality which includes JAX support. To install `S2FFT`
91-
with PyTorch support run
92-
93-
``` bash
94-
pip install s2fft[torch]
95-
```
100+
This will install all core functionality which includes JAX support (including PyTorch support).
96101

97102
Alternatively, the `S2FFT` package may be installed directly from GitHub by cloning this
98103
repository and then running
@@ -101,16 +106,22 @@ repository and then running
101106
pip install .
102107
```
103108

104-
from the root directory of the repository. To enable PyTorch support you will need to run
109+
from the root directory of the repository.
110+
111+
Unit tests can then be executed to ensure the installation was successful by first installing the test requirements and then running pytest
105112

106113
``` bash
107-
pip install .[torch]
114+
pip install -r requirements/requirements-tests.txt
115+
pytest tests/
108116
```
109117

110-
Unit tests can then be executed to ensure the installation was successful by running
118+
Documentation for the released version is available [here](https://astro-informatics.github.io/s2fft/). To build the documentation locally run
111119

112120
``` bash
113-
pytest tests/
121+
pip install -r requirements/requirements-docs.txt
122+
cd docs
123+
make html
124+
open _build/html/index.html
114125
```
115126

116127
> [!NOTE]
@@ -143,7 +154,29 @@ For further details on usage see the [documentation](https://astro-informatics.g
143154
> [!NOTE]
144155
> We also provide PyTorch support for the precompute version of our transforms. These are called through forward/inverse_torch(). Full PyTorch support will be provided in future releases.
145156
146-
## Benchmarking :hourglass_flowing_sand:
157+
## C/C++ JAX Frontends for SSHT/HEALPix :bulb:
158+
159+
`S2FFT` also provides JAX support for existing C/C++ packages, specifically [`HEALPix`](https://healpix.jpl.nasa.gov) and [`SSHT`](https://github.com/astro-informatics/ssht). This works
160+
by wrapping python bindings with custom JAX frontends. Note that this C/C++ to JAX interoperability is currently limited to CPU.
161+
162+
For example, one may call these alternate backends for the spherical harmonic transform by:
163+
164+
``` python
165+
# Forward SSHT spherical harmonic transform
166+
flm = s2fft.forward(f, L, sampling=["mw"], method="jax_ssht")
167+
168+
# Forward HEALPix spherical harmonic transform
169+
flm = s2fft.forward(f, L, nside=nside, sampling="healpix", method="jax_healpy")
170+
```
171+
172+
All of these JAX frontends supports out of the box reverse mode automatic differentiation,
173+
and under the hood is simply linking to the C/C++ packages you are familiar with. In this
174+
way `S2fft` enhances existing packages with gradient functionality for modern scientific computing or machine learning
175+
applications!
176+
177+
For further details on usage see the associated [notebooks](https://astro-informatics.github.io/s2fft/tutorials/spherical_harmonic/JAX_SSHT_backend.html).
178+
179+
<!-- ## Benchmarking :hourglass_flowing_sand:
147180
148181
We benchmarked the spherical harmonic and Wigner transforms implemented
149182
in `S2FFT` against the C implementations in the
@@ -167,7 +200,7 @@ that scale linearly with spin).
167200
| 8192 | 82 s | 110.8 | 2.14E-13 | N/A | N/A | N/A | N/A |
168201
169202
where the left hand results are for the recursive based algorithm and the right hand side are
170-
our precompute implementation.
203+
our precompute implementation. -->
171204

172205
## Contributors ✨
173206

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
:html_theme.sidebar_secondary.remove:
2+
3+
**************************
4+
C/C++ custom JAX support
5+
**************************
6+
.. automodule:: s2fft.transforms.c_backend_spherical
7+
:members:

docs/api/transforms/index.rst

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,25 @@ Transforms
4141
* - :func:`~s2fft.transforms.wigner.forward_jax`
4242
- Forward Wigner transform (JAX)
4343

44+
.. list-table:: C/C++ backend gradient support
45+
:widths: 25 25
46+
:header-rows: 1
47+
48+
* - Function Name
49+
- Description
50+
* - :func:`~s2fft.transforms.c_backend_spherical.ssht_inverse`
51+
- Custom JAX frontend for inverse SSHT C spherical harmonic library.
52+
* - :func:`~s2fft.transforms.c_backend_spherical.ssht_forward`
53+
- Custom JAX frontend for forward SSHT C spherical harmonic library.
54+
* - :func:`~s2fft.transforms.c_backend_spherical.healpy_inverse`
55+
- Custom JAX frontend for inverse HEALPix C++ spherical harmonic library.
56+
* - :func:`~s2fft.transforms.c_backend_spherical.healpy_forward`
57+
- Custom JAX frontend for forwardHEALPix C++ spherical harmonic library.
58+
* - :func:`~s2fft.transforms.wigner.inverse_jax_ssht`
59+
- Custom JAX frontend for hybrid inverse SSHT C Wigner transforms.
60+
* - :func:`~s2fft.transforms.wigner.forward_jax_ssht`
61+
- Custom JAX frontend for hybrid forward SSHT C Wigner transforms.
62+
4463
.. list-table:: On-the-fly Price-McEwen recursions.
4564
:widths: 25 25
4665
:header-rows: 1
@@ -64,4 +83,5 @@ Transforms
6483
on_the_fly_recursions
6584
spin_spherical_transform
6685
wigner
86+
.. c_backend_spherical
6787

127 KB
Loading

docs/conf.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525
author = "Matthew Price, Jason McEwen, Matthew Graham, Sofia Miñano, Devaraj Gopinathan"
2626

2727
# The short X.Y version
28-
version = "1.0.2"
28+
version = "1.1.0"
2929
# The full version, including alpha/beta/rc tags
30-
release = "1.0.2"
30+
release = "1.1.0"
3131

3232

3333
# -- General configuration ---------------------------------------------------
@@ -106,12 +106,12 @@
106106
"icon": "_static/arxiv-logomark-small.png",
107107
"type": "local",
108108
},
109-
# {
110-
# "name": "YouTube",
111-
# "url": "https://www.youtube.com/channel/UCrCOQsyQOJhOUaIYzmbkKQQ",
112-
# "icon": "fa-brands fa-youtube fa-2x",
113-
# "type": "fontawesome",
114-
# },
109+
{
110+
"name": "Medium",
111+
"url": "https://towardsdatascience.com/differentiable-and-accelerated-spherical-harmonic-transforms-c269393d08f1",
112+
"icon": "fa-brands fa-medium",
113+
"type": "fontawesome",
114+
},
115115
{
116116
"name": "PyPi",
117117
"url": "https://pypi.org/project/s2fft/",

docs/index.rst

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,19 @@ transforms (for both real and complex signals), with support for adjoint transfo
1111
where needed, and comes with different optimisations (precompute or not) that one
1212
may select depending on available resources and desired angular resolution :math:`L`.
1313

14-
As of version 1.0.2 ``S2FFT`` also provides PyTorch implementations of underlying
15-
precompute transforms. In future releases this support will be extended to our
16-
on-the-fly algorithms.
14+
.. important::
15+
HEALPix long JIT compile time fixed for CPU! Fix for GPU coming soon.
16+
17+
.. tip::
18+
As of version 1.0.2 ``S2FFT`` also provides PyTorch implementations of underlying
19+
precompute transforms. In future releases this support will be extended to our
20+
on-the-fly algorithms.
21+
22+
.. tip::
23+
As of version 1.1.0 ``S2FFT`` also provides JAX support for existing C/C++ packages,
24+
specifically ``HEALPix`` and ``SSHT``. This works by wrapping python bindings with custom
25+
JAX frontends. Note that currently this C/C++ to JAX interoperability is currently
26+
limited to CPU.
1727

1828
Algorithms |:zap:|
1929
-------------------
@@ -40,7 +50,7 @@ diagram below illustrates the separable spherical harmonic transform.
4050
.. image:: ./assets/figures/sax_schematic_github_docs.png
4151

4252
.. note::
43-
For algorithmic reasons JIT compilation of HEALPix transforms can become slow at high bandlimits, due to XLA unfolding of loops which currently cannot be avoided. After compiling HEALPix transforms should execute with the efficiency outlined in the associated paper, therefore this additional time overhead need only be incurred once. We are aware of this issue and will work to improve this in subsequent versions.
53+
For algorithmic reasons JIT compilation of HEALPix transforms can become slow at high bandlimits, due to XLA unfolding of loops which currently cannot be avoided. After compiling HEALPix transforms should execute with the efficiency outlined in the associated paper, therefore this additional time overhead need only be incurred once. We are aware of this issue and are working to fix it. A fix for CPU execution has now been implemented (see example `notebook <https://astro-informatics.github.io/s2fft/tutorials/spherical_harmonic/JAX_HEALPix_backend.html>`_). Fix for GPU execution is coming soon.
4454

4555
Sampling |:earth_africa:|
4656
-----------------------------------
@@ -53,7 +63,7 @@ The equiangular sampling schemes of `McEwen & Wiaux (2012) <https://arxiv.org/ab
5363
The popular `HEALPix <https://healpix.jpl.nasa.gov>`_ sampling scheme (`Gorski et al. 2005 <https://arxiv.org/abs/astro-ph/0409513>`_) is also supported. The HEALPix sampling does not exhibit a sampling theorem and so the corresponding harmonic transforms do not achieve machine precision but exhibit some error. However, the HEALPix sampling provides pixels of equal areas, which has many practical advantages.
5464

5565
.. image:: ./assets/figures/spherical_sampling.png
56-
:width: 700
66+
:width: 900
5767
:align: center
5868

5969
Contributors ✨
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{
2+
"path": "../../../notebooks/JAX_HEALPix_frontend.ipynb"
3+
}

0 commit comments

Comments
 (0)