Commit 922dae6

Update README with CUDA extension install details and tweak to Pytorch support note (#302)
1 parent 81781d8 commit 922dae6

1 file changed

+72
-43
lines changed


README.md

Lines changed: 72 additions & 43 deletions
@@ -18,8 +18,8 @@
1818
</div>
1919

2020
`S2FFT` is a Python package for computing Fourier transforms on the sphere
21-
and rotation group [(Price & McEwen 2024)](https://arxiv.org/abs/2311.14670) using
22-
JAX or PyTorch. It leverages autodiff to provide differentiable transforms, which are
21+
and rotation group [(Price & McEwen 2024)](https://arxiv.org/abs/2311.14670) using
22+
JAX or PyTorch. It leverages autodiff to provide differentiable transforms, which are
2323
also deployable on hardware accelerators (e.g. GPUs and TPUs).
2424

2525
More specifically, `S2FFT` provides support for spin spherical harmonic
@@ -41,16 +41,16 @@ angular resolution $L$. The diagram below illustrates the recursions
4141
<img class="dark-light" alt="Schematic of Wigner recursions" src="https://raw.githubusercontent.com/astro-informatics/s2fft/main/docs/assets/figures/Wigner_recursion_legend_darkmode.png" />
4242
</div>
4343

44-
With this recursion to hand, the spherical harmonic coefficients of an
45-
isolatitudinally sampled map may be computed as a two-step process. First,
46-
a 1D Fourier transform over longitude, for each latitudinal ring. Second,
47-
a projection onto the real polar-d functions. One may precompute and store
48-
all real polar-d functions for extreme acceleration, however this comes
49-
with an equally extreme memory overhead, which is infeasible at $L \sim 1024$.
50-
Alternatively, the real polar-d functions may be calculated recursively,
51-
computing only a portion of the projection at a time, hence incurring
52-
negligible memory overhead at the cost of slightly slower execution. The
53-
diagram below illustrates the separable spherical harmonic transform
44+
With this recursion to hand, the spherical harmonic coefficients of an
45+
isolatitudinally sampled map may be computed as a two-step process. First,
46+
a 1D Fourier transform over longitude, for each latitudinal ring. Second,
47+
a projection onto the real polar-d functions. One may precompute and store
48+
all real polar-d functions for extreme acceleration, however this comes
49+
with an equally extreme memory overhead, which is infeasible at $L \sim 1024$.
50+
Alternatively, the real polar-d functions may be calculated recursively,
51+
computing only a portion of the projection at a time, hence incurring
52+
negligible memory overhead at the cost of slightly slower execution. The
53+
diagram below illustrates the separable spherical harmonic transform
5454
(for further details see [Price & McEwen 2024](https://arxiv.org/abs/2311.14670)).
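
As a rough sketch of the two steps described above (not necessarily the exact convention used in `S2FFT`), consider a spin-$s$ signal $f$ sampled on rings at colatitudes $\theta_t$ and longitudes $\phi_p$. The quadrature weights $q_t$ are an assumed placeholder that depends on the sampling scheme, and the normalisation follows one common convention for the spin spherical harmonics.

$$
F_m(\theta_t) = \sum_{p} f(\theta_t, \phi_p)\, e^{-i m \phi_p}
$$

$$
f_{\ell m} = (-1)^s \sqrt{\frac{2\ell + 1}{4\pi}} \sum_{t} q_t \, d^{\ell}_{m,-s}(\theta_t)\, F_m(\theta_t)
$$

The first line is the 1D Fourier transform over longitude for each ring. The second is the projection onto the real polar-d functions $d^{\ell}_{m,-s}$, which can either be precomputed or generated recursively as discussed above.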
5555

5656
<div style="text-align: center;" align="center">
@@ -65,7 +65,7 @@ supported.
6565

6666
The equiangular sampling schemes of [McEwen & Wiaux
6767
(2012)](https://arxiv.org/abs/1110.6298), [Driscoll & Healy
68-
(1995)](https://www.sciencedirect.com/science/article/pii/S0196885884710086)
68+
(1995)](https://www.sciencedirect.com/science/article/pii/S0196885884710086)
6969
and [Gauss-Legendre (1986)](https://link.springer.com/article/10.1007/BF02519350)
7070
are supported, which exhibit associated sampling theorems and so
7171
harmonic transforms can be computed to machine precision. Note that the
@@ -84,8 +84,17 @@ pixels of equal areas, which has many practical advantages.
8484
<img class="dark-light" alt="Visualization of spherical sampling schemes" src="https://raw.githubusercontent.com/astro-informatics/s2fft/main/docs/assets/figures/spherical_sampling.png" width="700">
8585
</div>
8686

87-
> [!NOTE]
88-
> For algorithmic reasons JIT compilation of HEALPix transforms can become slow at high bandlimits, due to XLA unfolding of loops which currently cannot be avoided. After compiling HEALPix transforms should execute with the efficiency outlined in the associated paper, therefore this additional time overhead need only be incurred once. We are aware of this issue and are working to fix it. A fix for CPU execution has now been implemented (see example [notebook](https://astro-informatics.github.io/s2fft/tutorials/spherical_harmonic/JAX_HEALPix_backend.html)).
87+
> [!NOTE]
88+
> For algorithmic reasons JIT compilation of HEALPix transforms can become slow at high bandlimits, due to XLA unfolding of loops.
89+
> After compilation, HEALPix transforms should execute with the efficiency outlined in the associated paper, so this additional time overhead need only be incurred once.
90+
>
91+
> __If running on a CPU__, we provide (differentiable) JAX wrappers of the [`healpy`](https://healpy.readthedocs.io/en/latest/) transforms which can be used to sidestep the issue.
92+
> This implementation can be selected by passing a `method="jax_healpy"` keyword argument to the `s2fft.forward` or `s2fft.inverse` functions -
93+
> see example [notebook](https://astro-informatics.github.io/s2fft/tutorials/spherical_harmonic/JAX_HEALPix_backend.html).
94+
>
95+
> __If running on a GPU__, a CUDA extension module is available which avoids the long compilation time.
96+
> This implementation can be selected by passing a `method="jax_cuda"` keyword argument to the `s2fft.forward` or `s2fft.inverse` functions.
97+
> Currently we do not publish binary wheels with CUDA extension support, so you will need to [build the package from source](#cuda-extension-support) to use this functionality.
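
One way to apply the note above is to choose the `method` string from the platform JAX is running on. The helper below is a minimal sketch: `pick_healpix_method` is a hypothetical name, not part of `S2FFT`, and it assumes the platform string comes from something like `jax.default_backend()`. Only the method names `"jax_healpy"`, `"jax_cuda"` and `"jax"` are taken from this README.

```python
def pick_healpix_method(platform: str) -> str:
    """Map a platform name to an s2fft `method` string (hypothetical helper).

    "jax_healpy" wraps healpy and is CPU only; "jax_cuda" needs a source build
    with the CUDA extension; "jax" is the pure-JAX fallback.
    """
    platform = platform.lower()
    if platform == "cpu":
        return "jax_healpy"
    if platform in ("gpu", "cuda"):
        return "jax_cuda"
    return "jax"


# Assumed usage, with f, L and nside defined as in the usage examples:
#   import jax, s2fft
#   method = pick_healpix_method(jax.default_backend())
#   flm = s2fft.forward(f, L, nside=nside, sampling="healpix", method=method)
```

Keeping the choice in one place makes it easy to fall back to `"jax"` when the CUDA extension was not built.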
8998
9099
## Installation 💻
91100

@@ -101,16 +110,36 @@ if you wish to install JAX with GPU or TPU support,
101110
you should first follow the [relevant installation instructions in JAX's documentation](https://docs.jax.dev/en/latest/installation.html#installation)
102111
and then install `S2FFT` as above.
103112

104-
Alternatively, the latest development version of `S2FFT` may be installed directly from GitHub by running
113+
Alternatively, the latest development version of `S2FFT` may be installed directly from GitHub by running
105114

106115
```bash
107-
pip install git+https://github.com/astro-informatics/s2fft
116+
pip install git+https://github.com/astro-informatics/s2fft
108117
```
109118

119+
### CUDA extension support
120+
121+
To install the package with support for the CUDA extension module giving reduced compile times for running HEALPix transforms on the GPU,
122+
you will need to build from source on a system with CUDA (tested with version 12.3) and CMake (versions 3.19+) installed.
123+
124+
To install the latest development version from source in verbose mode run
125+
126+
```bash
127+
pip install -v git+https://github.com/astro-informatics/s2fft
128+
```
129+
130+
or to install a specific release tag in verbose mode run
131+
132+
```bash
133+
pip install -v git+https://github.com/astro-informatics/s2fft@TAG
134+
```
135+
136+
where `TAG` is the relevant version tag.
137+
The output should indicate if the CUDA install on your system is successfully detected.
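
Before building from source, it may help to confirm that the required tools are visible on your `PATH`. The snippet below is a small sketch using only the Python standard library. It assumes that finding `nvcc` is a reasonable proxy for a usable CUDA toolkit, which may not match how the build actually detects CUDA, and `find_build_tools` is a hypothetical helper name.

```python
import shutil


def find_build_tools(tools=("nvcc", "cmake")):
    """Return a dict mapping each tool name to its path on PATH, or None if absent."""
    return {tool: shutil.which(tool) for tool in tools}


if __name__ == "__main__":
    # Report which of the assumed build prerequisites can be found.
    for tool, path in find_build_tools().items():
        print(f"{tool}: {path or 'not found on PATH'}")
```

If either tool is reported missing, the verbose `pip install -v` output is still the authoritative indication of whether CUDA was detected.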
138+
110139
## Tests 🚦
111140

112141
A `pytest` test suite for the package is included in the `tests` directory.
113-
To install the test dependencies, clone the repository and install the package (in [editable mode](https://setuptools.pypa.io/en/latest/userguide/development_mode.html))
142+
To install the test dependencies, clone the repository and install the package (in [editable mode](https://setuptools.pypa.io/en/latest/userguide/development_mode.html))
114143
with the extra test dependencies by running from the root of the repository
115144

116145
```bash
@@ -120,13 +149,13 @@ pip install -e ".[tests]"
120149
To run the tests, run from the root of the repository
121150

122151
```bash
123-
pytest
152+
pytest
124153
```
125154

126155
## Documentation 📖
127156

128-
Documentation for the released version is available [here](https://astro-informatics.github.io/s2fft/).
129-
To install the documentation dependencies, clone the repository and install the package (in [editable mode](https://setuptools.pypa.io/en/latest/userguide/development_mode.html))
157+
Documentation for the released version is available [here](https://astro-informatics.github.io/s2fft/).
158+
To install the documentation dependencies, clone the repository and install the package (in [editable mode](https://setuptools.pypa.io/en/latest/userguide/development_mode.html))
130159
with the extra documentation dependencies by running from the root of the repository
131160

132161
```bash
@@ -136,7 +165,7 @@ pip install -e ".[docs]"
136165
To build the documentation, run from the root of the repository
137166

138167
```bash
139-
cd docs
168+
cd docs
140169
make html
141170
open _build/html/index.html
142171
```
@@ -146,7 +175,7 @@ open _build/html/index.html
146175
A series of tutorial notebooks are included in the `notebooks` directory
147176
and rendered [in the documentation](https://astro-informatics.github.io/s2fft/tutorials/index.html).
148177

149-
To install the dependencies required to run the notebooks locally, clone the repository and install the package (in [editable mode](https://setuptools.pypa.io/en/latest/userguide/development_mode.html))
178+
To install the dependencies required to run the notebooks locally, clone the repository and install the package (in [editable mode](https://setuptools.pypa.io/en/latest/userguide/development_mode.html))
150179
with the extra documentation and plotting dependencies by running from the root of the repository
151180

152181
```bash
@@ -172,12 +201,12 @@ import s2fft
172201
f = ...
173202
L = ...
174203
# Compute harmonic coefficients
175-
flm = s2fft.forward(f, L, method="jax")
204+
flm = s2fft.forward(f, L, method="jax")
176205
# Map back to pixel-space signal
177206
f = s2fft.inverse(flm, L, method="jax")
178207
```
179208

180-
For a signal on the rotation group
209+
For a signal on the rotation group
181210

182211
```python
183212
import s2fft
@@ -194,27 +223,27 @@ f = fft.wigner.inverse_jax(flmn, L, N, method="jax")
194223

195224
For further details on usage see the [documentation](https://astro-informatics.github.io/s2fft/) and associated [notebooks](https://astro-informatics.github.io/s2fft/tutorials/spherical_harmonic/spherical_harmonic_transform.html).
196225

197-
> [!NOTE]
198-
> We also provide PyTorch support for our transforms, as demonstrated in the [_Torch frontend_ tutorial notebook](https://astro-informatics.github.io/s2fft/tutorials/torch_frontend/torch_frontend.html). This wraps the JAX implementations so JAX will need to be installed in addition to PyTorch.
226+
We also provide PyTorch support for our transforms, as demonstrated in the [_Torch frontend_ tutorial notebook](https://astro-informatics.github.io/s2fft/tutorials/torch_frontend/torch_frontend.html).
227+
This wraps the JAX implementations so JAX will need to be installed in addition to PyTorch.
199228

200229
## SSHT & HEALPix wrappers 💡
201230

202-
`S2FFT` also provides JAX support for existing C/C++ packages, specifically [`HEALPix`](https://healpix.jpl.nasa.gov) and [`SSHT`](https://github.com/astro-informatics/ssht). This works
231+
`S2FFT` also provides JAX support for existing C/C++ packages, specifically [`HEALPix`](https://healpix.jpl.nasa.gov) and [`SSHT`](https://github.com/astro-informatics/ssht). This works
203232
by wrapping Python bindings with custom JAX frontends. Note that this C/C++ to JAX interoperability is currently limited to CPU.
204233

205234
For example, one may call these alternate backends for the spherical harmonic transform by:
206235

207236
``` python
208237
# Forward SSHT spherical harmonic transform
209-
flm = s2fft.forward(f, L, sampling="mw", method="jax_ssht")
238+
flm = s2fft.forward(f, L, sampling="mw", method="jax_ssht")
210239

211240
# Forward HEALPix spherical harmonic transform
212-
flm = s2fft.forward(f, L, nside=nside, sampling="healpix", method="jax_healpy")
241+
flm = s2fft.forward(f, L, nside=nside, sampling="healpix", method="jax_healpy")
213242
```
214243

215-
All of these JAX frontends support out-of-the-box reverse-mode automatic differentiation,
216-
and under the hood simply link to the C/C++ packages you are familiar with. In this
217-
way `S2FFT` enhances existing packages with gradient functionality for modern scientific computing or machine learning
244+
All of these JAX frontends support out-of-the-box reverse-mode automatic differentiation,
245+
and under the hood simply link to the C/C++ packages you are familiar with. In this
246+
way `S2FFT` enhances existing packages with gradient functionality for modern scientific computing or machine learning
218247
applications!
219248

220249
For further details on usage see the associated [notebooks](https://astro-informatics.github.io/s2fft/tutorials/spherical_harmonic/JAX_SSHT_backend.html).
@@ -264,8 +293,8 @@ patterns!
264293
Should this code be used in any way, we kindly request that the following article is
265294
referenced. A BibTeX entry for this reference may look like:
266295

267-
```
268-
@article{price:s2fft,
296+
```
297+
@article{price:s2fft,
269298
author = "Matthew A. Price and Jason D. McEwen",
270299
title = "Differentiable and accelerated spherical harmonic and Wigner transforms",
271300
journal = "Journal of Computational Physics",
@@ -280,21 +309,21 @@ referenced. A BibTeX entry for this reference may look like:
280309
You might also like to consider citing our related papers on which this
281310
code builds:
282311

283-
```
312+
```
284313
@article{mcewen:fssht,
285314
author = "Jason D. McEwen and Yves Wiaux",
286315
title = "A novel sampling theorem on the sphere",
287316
journal = "IEEE Trans. Sig. Proc.",
288317
year = "2011",
289318
volume = "59",
290319
number = "12",
291-
pages = "5876--5887",
320+
pages = "5876--5887",
292321
eprint = "arXiv:1110.6298",
293322
doi = "10.1109/TSP.2011.2166394"
294323
}
295324
```
296325

297-
```
326+
```
298327
@article{mcewen:so3,
299328
author = "Jason D. McEwen and Martin B{\"u}ttner and Boris ~Leistedt and Hiranya V. Peiris and Yves Wiaux",
300329
title = "A novel sampling theorem on the rotation group",
@@ -304,7 +333,7 @@ code builds:
304333
number = "12",
305334
pages = "2425--2429",
306335
eprint = "arXiv:1508.03101",
307-
doi = "10.1109/LSP.2015.2490676"
336+
doi = "10.1109/LSP.2015.2490676"
308337
}
309338
```
310339

@@ -319,10 +348,10 @@ Copyright 2023 Matthew Price, Jason McEwen and contributors.
319348
details see the [`LICENCE.txt`](https://github.com/astro-informatics/s2fft/blob/main/LICENCE.txt) file.
320349

321350
The file [`lib/include/kernel_helpers.h`](https://github.com/astro-informatics/s2fft/blob/main/lib/include/kernel_helpers.h) is adapted from
322-
[code](https://github.com/dfm/extending-jax/blob/c33869665236877a2ae281f3f5dbff579e8f5b00/lib/kernel_helpers.h) in [a tutorial on extending JAX](https://github.com/dfm/extending-jax) by
351+
[code](https://github.com/dfm/extending-jax/blob/c33869665236877a2ae281f3f5dbff579e8f5b00/lib/kernel_helpers.h) in [a tutorial on extending JAX](https://github.com/dfm/extending-jax) by
323352
[Dan Foreman-Mackey](https://github.com/dfm) and licensed under an [MIT license](https://github.com/dfm/extending-jax/blob/371dca93c6405368fa8e71690afd3968d75f4bac/LICENSE).
324353

325354
The file [`lib/include/kernel_nanobind_helpers.h`](https://github.com/astro-informatics/s2fft/blob/main/lib/include/kernel_nanobind_helpers.h)
326-
is adapted from [code](https://github.com/jax-ml/jax/blob/3d389a7fb440c412d95a1f70ffb91d58408247d0/jaxlib/kernel_nanobind_helpers.h)
327-
by the [JAX](https://github.com/jax-ml/jax) authors
328-
and licensed under an [Apache-2.0 license](https://github.com/jax-ml/jax/blob/3d389a7fb440c412d95a1f70ffb91d58408247d0/LICENSE).
355+
is adapted from [code](https://github.com/jax-ml/jax/blob/3d389a7fb440c412d95a1f70ffb91d58408247d0/jaxlib/kernel_nanobind_helpers.h)
356+
by the [JAX](https://github.com/jax-ml/jax) authors
357+
and licensed under an [Apache-2.0 license](https://github.com/jax-ml/jax/blob/3d389a7fb440c412d95a1f70ffb91d58408247d0/LICENSE).
