Skip to content

Commit fe8ac40

Browse files
authored
Update python_requires and test matrix to support Python 3.11+ (#305)
* Update python_requires and test matrix * Ruff autofixes for type hints with 3.11+ features * Use miniforge to install pytorch / healpy on MacOS * Try using conda-pypi to install dependencies on MacOS * Manually specify dependencies to install with conda * Fix pytorch conda package name and skip PyPI dependency install on MacOS * Add tmate step to allow debugging * Remove tmate and use explicit shell * Set explicit shell options as default for job + relax NumPy requirement * Readd upper bound on NumPy version * Exclude Python 3.13 on MacOS from matrix
1 parent 133b92f commit fe8ac40

File tree

17 files changed

+96
-78
lines changed

17 files changed

+96
-78
lines changed

.github/workflows/tests.yml

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,18 @@ jobs:
2525
build:
2626

2727
runs-on: ${{ matrix.os }}
28+
defaults:
29+
run:
30+
shell: bash -el {0}
2831
strategy:
2932
matrix:
30-
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
31-
os: [ubuntu-latest]
32-
include:
33+
python-version: ["3.11", "3.12", "3.13"]
34+
os: [ubuntu-latest, macos-latest]
35+
exclude:
36+
# Skip Python 3.13 on MacOS as 1.20<=numpy<2 requirement inherited from so3
37+
# requiring numpy<2 cannot be resolved there
3338
- os: macos-latest
34-
python-version: "3.8"
39+
python-version: "3.13"
3540
fail-fast: false
3641
env:
3742
CMAKE_POLICY_VERSION_MINIMUM: 3.5
@@ -42,14 +47,34 @@ jobs:
4247
with:
4348
fetch-depth: 0
4449
fetch-tags: true
45-
46-
- name: Set up Python ${{ matrix.python-version }}
50+
51+
- if: matrix.os == 'macos-latest'
52+
name: Set up Miniforge on MacOS
53+
uses: conda-incubator/setup-miniconda@v3
54+
with:
55+
miniforge-version: latest
56+
python-version: ${{ matrix.python-version }}
57+
58+
- if: matrix.os == 'macos-latest'
59+
name: Install dependencies with conda on MacOS
60+
# Avoid OpenMP runtime incompatibility when using PyPI wheels
61+
# by installing torch and healpy using conda
62+
# https://github.com/healpy/healpy/issues/1012
63+
run: |
64+
conda install jax "jax>=0.3.13,<0.6.0" "numpy>=1.20,<2" ducc0 healpy pytorch pytest pytest-cov
65+
python -m pip install --upgrade pip
66+
pip install --no-deps so3 pyssht
67+
pip install --no-deps .
68+
69+
- if: matrix.os != 'macos-latest'
70+
name: Set up Python ${{ matrix.python-version }}
4771
uses: actions/setup-python@v5
4872
with:
4973
python-version: ${{ matrix.python-version }}
5074
cache: pip
5175

52-
- name: Install dependencies
76+
- if: matrix.os != 'macos-latest'
77+
name: Install dependencies
5378
run: |
5479
python -m pip install --upgrade pip
5580
pip install .[tests]

benchmarks/benchmarking.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,10 @@ def _format_results_entry(results_entry: dict) -> str:
252252

253253
def _dict_product(dicts: dict[str, Iterable[Any]]) -> Iterable[dict[str, Any]]:
254254
"""Generator corresponding to Cartesian product of dictionaries."""
255-
return (dict(zip(dicts.keys(), values)) for values in product(*dicts.values()))
255+
return (
256+
dict(zip(dicts.keys(), values, strict=False))
257+
for values in product(*dicts.values())
258+
)
256259

257260

258261
def _parse_value(value: str) -> Any:

benchmarks/plotting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,10 @@ def plot_results_against_bandlimit(
141141
squeeze=False,
142142
)
143143
axes = axes.T if functions_along_columns else axes
144-
for axes_row, function in zip(axes, functions):
144+
for axes_row, function in zip(axes, functions, strict=False):
145145
results = benchmark_results["results"][function]
146146
l_values = np.array([r["parameters"]["L"] for r in results])
147-
for ax, measurement in zip(axes_row, measurements):
147+
for ax, measurement in zip(axes_row, measurements, strict=False):
148148
plot_function, label = _measurement_plot_functions_and_labels[measurement]
149149
try:
150150
plot_function(ax, "L", l_values, results)

pyproject.toml

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
requires = [
33
"setuptools",
44
"setuptools-scm",
5-
"scikit-build-core >=0.4.3",
6-
"nanobind >=1.3.2"
5+
"scikit-build-core>=0.4.3",
6+
"nanobind>=1.3.2"
77
]
88
build-backend = "scikit_build_core.build"
99

@@ -16,11 +16,9 @@ authors = [
1616
classifiers = [
1717
"Programming Language :: Python :: 3",
1818
"Programming Language :: Python :: 3 :: Only",
19-
"Programming Language :: Python :: 3.8",
20-
"Programming Language :: Python :: 3.9",
21-
"Programming Language :: Python :: 3.10",
2219
"Programming Language :: Python :: 3.11",
2320
"Programming Language :: Python :: 3.12",
21+
"Programming Language :: Python :: 3.13",
2422
"Operating System :: OS Independent",
2523
"Intended Audience :: Developers",
2624
"Intended Audience :: Science/Research",
@@ -38,7 +36,7 @@ keywords = [
3836
]
3937
name = "s2fft"
4038
readme = "README.md"
41-
requires-python = ">=3.8"
39+
requires-python = ">=3.11"
4240
license.file = "LICENCE.txt"
4341
urls.homepage = "https://github.com/astro-informatics/s2fft"
4442

s2fft/precompute_transforms/construct.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from typing import Tuple
21
from warnings import warn
32

43
import jax
@@ -612,7 +611,7 @@ def wigner_kernel_jax(
612611
wigner_kernel_torch = torch_wrapper.wrap_as_torch_function(wigner_kernel_jax)
613612

614613

615-
def fourier_wigner_kernel(L: int) -> Tuple[np.ndarray, np.ndarray]:
614+
def fourier_wigner_kernel(L: int) -> tuple[np.ndarray, np.ndarray]:
616615
"""
617616
Computes Fourier coefficients of the reduced Wigner d-functions and quadrature
618617
weights upsampled for the forward Fourier-Wigner transform.
@@ -640,7 +639,7 @@ def fourier_wigner_kernel(L: int) -> Tuple[np.ndarray, np.ndarray]:
640639
return deltas, w
641640

642641

643-
def fourier_wigner_kernel_jax(L: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
642+
def fourier_wigner_kernel_jax(L: int) -> tuple[jnp.ndarray, jnp.ndarray]:
644643
"""
645644
Computes Fourier coefficients of the reduced Wigner d-functions and quadrature
646645
weights upsampled for the forward Fourier-Wigner transform (JAX implementation).

s2fft/precompute_transforms/custom_ops.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from functools import partial
2-
from typing import Tuple
32

43
import jax.numpy as jnp
54
import numpy as np
@@ -9,7 +8,7 @@
98
def wigner_subset_to_s2(
109
flmn: np.ndarray,
1110
spins: np.ndarray,
12-
DW: Tuple[np.ndarray, np.ndarray],
11+
DW: tuple[np.ndarray, np.ndarray],
1312
L: int,
1413
sampling: str = "mw",
1514
) -> np.ndarray:
@@ -91,7 +90,7 @@ def wigner_subset_to_s2(
9190
def wigner_subset_to_s2_jax(
9291
flmn: jnp.ndarray,
9392
spins: jnp.ndarray,
94-
DW: Tuple[jnp.ndarray, jnp.ndarray],
93+
DW: tuple[jnp.ndarray, jnp.ndarray],
9594
L: int,
9695
sampling: str = "mw",
9796
) -> jnp.ndarray:
@@ -173,7 +172,7 @@ def wigner_subset_to_s2_jax(
173172
def so3_to_wigner_subset(
174173
f: np.ndarray,
175174
spins: np.ndarray,
176-
DW: Tuple[np.ndarray, np.ndarray],
175+
DW: tuple[np.ndarray, np.ndarray],
177176
L: int,
178177
N: int,
179178
sampling: str = "mw",
@@ -214,7 +213,7 @@ def so3_to_wigner_subset(
214213
def so3_to_wigner_subset_jax(
215214
f: jnp.ndarray,
216215
spins: jnp.ndarray,
217-
DW: Tuple[jnp.ndarray, jnp.ndarray],
216+
DW: tuple[jnp.ndarray, jnp.ndarray],
218217
L: int,
219218
N: int,
220219
sampling: str = "mw",
@@ -257,7 +256,7 @@ def so3_to_wigner_subset_jax(
257256
def s2_to_wigner_subset(
258257
fs: np.ndarray,
259258
spins: np.ndarray,
260-
DW: Tuple[np.ndarray, np.ndarray],
259+
DW: tuple[np.ndarray, np.ndarray],
261260
L: int,
262261
sampling: str = "mw",
263262
) -> np.ndarray:
@@ -343,7 +342,7 @@ def s2_to_wigner_subset(
343342
def s2_to_wigner_subset_jax(
344343
fs: jnp.ndarray,
345344
spins: jnp.ndarray,
346-
DW: Tuple[jnp.ndarray, jnp.ndarray],
345+
DW: tuple[jnp.ndarray, jnp.ndarray],
347346
L: int,
348347
sampling: str = "mw",
349348
) -> jnp.ndarray:

s2fft/precompute_transforms/spherical.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from functools import partial
2-
from typing import Optional
32
from warnings import warn
43

54
import jax.numpy as jnp
@@ -21,11 +20,11 @@ def inverse(
2120
flm: np.ndarray,
2221
L: int,
2322
spin: int = 0,
24-
kernel: Optional[np.ndarray] = None,
23+
kernel: np.ndarray | None = None,
2524
sampling: str = "mw",
2625
reality: bool = False,
2726
method: str = "jax",
28-
nside: Optional[int] = None,
27+
nside: int | None = None,
2928
) -> np.ndarray:
3029
r"""
3130
Compute the inverse spherical harmonic transform via precompute.
@@ -228,11 +227,11 @@ def forward(
228227
f: np.ndarray,
229228
L: int,
230229
spin: int = 0,
231-
kernel: Optional[np.ndarray] = None,
230+
kernel: np.ndarray | None = None,
232231
sampling: str = "mw",
233232
reality: bool = False,
234233
method: str = "jax",
235-
nside: Optional[int] = None,
234+
nside: int | None = None,
236235
iter: int = 0,
237236
) -> np.ndarray:
238237
r"""

s2fft/recursions/price_mcewen.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import warnings
22
from functools import partial
3-
from typing import List
43

54
import jax.lax as lax
65
import jax.numpy as jnp
@@ -19,7 +18,7 @@ def generate_precomputes(
1918
nside: int = None,
2019
forward: bool = False,
2120
L_lower: int = 0,
22-
) -> List[np.ndarray]:
21+
) -> list[np.ndarray]:
2322
r"""
2423
Compute recursion coefficients with :math:`\mathcal{O}(L^3)` memory overhead.
2524
@@ -125,7 +124,7 @@ def generate_precomputes_jax(
125124
forward: bool = False,
126125
L_lower: int = 0,
127126
betas: jnp.ndarray = None,
128-
) -> List[jnp.ndarray]:
127+
) -> list[jnp.ndarray]:
129128
r"""
130129
Compute recursion coefficients with :math:`\mathcal{O}(L^2)` memory overhead.
131130
In practice one could compute these on-the-fly but the memory overhead is
@@ -264,7 +263,7 @@ def generate_precomputes_wigner(
264263
forward: bool = False,
265264
reality: bool = False,
266265
L_lower: int = 0,
267-
) -> List[List[np.ndarray]]:
266+
) -> list[list[np.ndarray]]:
268267
r"""
269268
Compute recursion coefficients with :math:`\mathcal{O}(L^2)` memory overhead.
270269
In practice one could compute these on-the-fly but the memory overhead is
@@ -316,7 +315,7 @@ def generate_precomputes_wigner_jax(
316315
forward: bool = False,
317316
reality: bool = False,
318317
L_lower: int = 0,
319-
) -> List[List[jnp.ndarray]]:
318+
) -> list[list[jnp.ndarray]]:
320319
r"""
321320
Compute recursion coefficients with :math:`\mathcal{O}(L^2)` memory overhead.
322321
In practice one could compute these on-the-fly but the memory overhead is

s2fft/sampling/s2_samples.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import Tuple
2-
31
import numpy as np
42

53

@@ -125,7 +123,7 @@ def nphi_equiang(L: int, sampling: str = "mw") -> int:
125123
return 1
126124

127125

128-
def ftm_shape(L: int, sampling: str = "mw", nside: int = None) -> Tuple[int, int]:
126+
def ftm_shape(L: int, sampling: str = "mw", nside: int = None) -> tuple[int, int]:
129127
r"""
130128
Shape of intermediate array, before/after latitudinal step.
131129
@@ -445,7 +443,7 @@ def ring_phase_shift_hp(
445443
return np.exp(sign * 1j * np.arange(m_start_ind, L) * phi_offset)
446444

447445

448-
def f_shape(L: int = None, sampling: str = "mw", nside: int = None) -> Tuple[int]:
446+
def f_shape(L: int = None, sampling: str = "mw", nside: int = None) -> tuple[int]:
449447
r"""
450448
Shape of spherical signal.
451449
@@ -480,7 +478,7 @@ def f_shape(L: int = None, sampling: str = "mw", nside: int = None) -> Tuple[int
480478
return ntheta(L, sampling), nphi_equiang(L, sampling)
481479

482480

483-
def flm_shape(L: int) -> Tuple[int, int]:
481+
def flm_shape(L: int) -> tuple[int, int]:
484482
r"""
485483
Standard shape of harmonic coefficients.
486484

s2fft/sampling/so3_samples.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
1-
from typing import Tuple
2-
31
import numpy as np
42

53
from s2fft.sampling import s2_samples as samples
64

75

86
def f_shape(
97
L: int, N: int, sampling: str = "mw", nside: int = None
10-
) -> Tuple[int, int, int]:
8+
) -> tuple[int, int, int]:
119
r"""
1210
Computes the pixel-space sampling shape for signal on the rotation group
1311
:math:`SO(3)`.
@@ -49,7 +47,7 @@ def f_shape(
4947
raise ValueError(f"Sampling scheme sampling={sampling} not supported")
5048

5149

52-
def flmn_shape(L: int, N: int) -> Tuple[int, int, int]:
50+
def flmn_shape(L: int, N: int) -> tuple[int, int, int]:
5351
r"""
5452
Computes the shape of Wigner coefficients for signal on the rotation group
5553
:math:`SO(3)`.
@@ -69,7 +67,7 @@ def flmn_shape(L: int, N: int) -> Tuple[int, int, int]:
6967

7068
def fnab_shape(
7169
L: int, N: int, sampling: str = "mw", nside: int = None
72-
) -> Tuple[int, int, int]:
70+
) -> tuple[int, int, int]:
7371
r"""
7472
Computes the shape of Wigner coefficients for signal on the rotation group
7573
:math:`SO(3)`.

0 commit comments

Comments
 (0)