Skip to content

Commit 3abb71b

Browse files
committed
Use explicit Union rather than | in type hints for Python 3.9 compat
1 parent 6ca389d commit 3abb71b

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

s2fft/utils/jax_primitive.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from functools import partial
2-
from typing import Callable
2+
from typing import Callable, Dict, Optional, Union
33
from jax import core
44
from jax.interpreters import ad, batching, xla, mlir
55

@@ -8,10 +8,10 @@ def register_primitive(
88
name: str,
99
multiple_results: bool,
1010
abstract_evaluation: Callable,
11-
lowering_per_platform: dict[None | str, Callable],
12-
batcher: None | Callable = None,
13-
jacobian_vector_product: None | Callable = None,
14-
transpose: None | Callable = None,
11+
lowering_per_platform: Dict[Union[None, str], Callable],
12+
batcher: Optional[Callable] = None,
13+
jacobian_vector_product: Optional[Callable] = None,
14+
transpose: Optional[Callable] = None,
1515
):
1616
"""Register a new custom JAX primitive.
1717

0 commit comments

Comments
 (0)