Skip to content

Commit bbf33c8

Browse files
committed
removed _src imports when possible
1 parent e237437 commit bbf33c8

File tree

2 files changed

+25
-51
lines changed

2 files changed

+25
-51
lines changed

s2fft/utils/healpix_ffts.py

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,10 @@
99
import jaxlib.mlir.ir as ir
1010
from s2fft_lib import _s2fft
1111
from jaxlib.hlo_helpers import custom_call
12-
from jax._src.lib.mlir.dialects import hlo
13-
from jax._src.interpreters import mlir
14-
from jax.core import Primitive, ShapedArray
15-
from jax.interpreters import ad, xla
16-
from jax.sharding import Mesh, NamedSharding
17-
from jax.sharding import PartitionSpec as P
12+
from jax.core import ShapedArray
1813
from typing import Tuple
19-
from jax._src.api import ShapeDtypeStruct
20-
from math import prod
14+
from jax import ShapeDtypeStruct
15+
# did not find promote_dtypes_complex outside _src
2116
from jax._src.numpy.util import promote_dtypes_complex
2217

2318

@@ -709,8 +704,10 @@ def lowering(ctx, f, *, L, nside, reality, fft_type):
709704
# operand_output_aliases={0: 0},
710705
backend_config=opaque,
711706
)
712-
713-
return hlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), result).results
707+
# For multi GPU healpix fft, I will be using XLA buffer donation functionality
708+
# Which will return the same shape as the input in the CUDA primitive then it will be reshaped to the output shape
709+
# return hlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), result).results
710+
return result.results
714711

715712
@staticmethod
716713
def impl(f, L, nside, reality, fft_type):
@@ -719,21 +716,21 @@ def impl(f, L, nside, reality, fft_type):
719716

720717
# Multi GPU part
721718

722-
@staticmethod
723-
def per_shard_impl():
724-
return NotImplemented
725-
726-
@staticmethod
727-
def infer_sharding_from_operands(L, nside, reality, fft_type, mesh: Mesh,
728-
arg_infos: Tuple[ShapeDtypeStruct],
729-
result_infos: Tuple[ShapedArray]):
730-
return NotImplemented
731-
732-
@staticmethod
733-
def partition(L, nside, reality, fft_type, mesh: Mesh,
734-
arg_shapes: Tuple[Tuple[int]], result_shape: Tuple[int]):
735-
return NotImplemented
736-
719+
# @staticmethod
720+
# def per_shard_impl():
721+
# return NotImplemented
722+
#
723+
# @staticmethod
724+
# def infer_sharding_from_operands(L, nside, reality, fft_type, mesh: Mesh,
725+
# arg_infos: Tuple[ShapeDtypeStruct],
726+
# result_infos: Tuple[ShapedArray]):
727+
# return NotImplemented
728+
#
729+
# @staticmethod
730+
# def partition(L, nside, reality, fft_type, mesh: Mesh,
731+
# arg_shapes: Tuple[Tuple[int]], result_shape: Tuple[int]):
732+
# return NotImplemented
733+
#
737734

738735
register_primitive(HealpixFFTPrimitive)
739736

s2fft/utils/jax_pritimive.py

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,10 @@
11
from abc import ABCMeta, abstractmethod
2-
from dataclasses import dataclass
32
from functools import partial
4-
from typing import Tuple
5-
6-
import jax.numpy as jnp
7-
from jax import jit, lax
8-
from jax._src.api import ShapeDtypeStruct
9-
from jax._src.core import ShapedArray
10-
from jax._src.typing import Array, ArrayLike
113
from jax.experimental.custom_partitioning import custom_partitioning
12-
from jax.lax import dynamic_slice
13-
from jax.sharding import Mesh, NamedSharding
14-
from jax.sharding import PartitionSpec as P
15-
from abc import ABCMeta, abstractmethod
16-
from dataclasses import dataclass
17-
from typing import Tuple, Sequence, Union, Callable
18-
from functools import partial, reduce
19-
import operator
20-
import os
21-
import warnings
22-
23-
import numpy as np
24-
import jax.numpy as jnp
25-
from jax.lib import xla_client
26-
from jax import core, dtypes
4+
from jax import core
275
from jax.interpreters import xla, mlir
28-
from jax.experimental.custom_partitioning import custom_partitioning
29-
from jax.interpreters.mlir import ir, dtype_to_ir_type
30-
from jax._src.interpreters import batching
6+
from jax.interpreters import batching
7+
# dispatch is not exposed outside of jax._src
318
from jax._src import dispatch
329

3310
# Inspired by https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/jax/cpp_extensions.py

0 commit comments

Comments
 (0)