9
9
import jaxlib .mlir .ir as ir
10
10
from s2fft_lib import _s2fft
11
11
from jaxlib .hlo_helpers import custom_call
12
- from jax ._src .lib .mlir .dialects import hlo
13
- from jax ._src .interpreters import mlir
14
- from jax .core import Primitive , ShapedArray
15
- from jax .interpreters import ad , xla
16
- from jax .sharding import Mesh , NamedSharding
17
- from jax .sharding import PartitionSpec as P
12
+ from jax .core import ShapedArray
18
13
from typing import Tuple
19
- from jax . _src . api import ShapeDtypeStruct
20
- from math import prod
14
+ from jax import ShapeDtypeStruct
15
+ # did not find promote_dtypes_complex outside _src
21
16
from jax ._src .numpy .util import promote_dtypes_complex
22
17
23
18
@@ -709,8 +704,10 @@ def lowering(ctx, f, *, L, nside, reality, fft_type):
709
704
# operand_output_aliases={0: 0},
710
705
backend_config = opaque ,
711
706
)
712
-
713
- return hlo .ReshapeOp (mlir .aval_to_ir_type (aval_out ), result ).results
707
+ # For multi GPU healpix fft, I will be using XLA buffer donation functionality
708
+ # Which will return the same shape as the input in the CUDA primitive then it will be reshaped to the output shape
709
+ # return hlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), result).results
710
+ return result .results
714
711
715
712
@staticmethod
716
713
def impl (f , L , nside , reality , fft_type ):
@@ -719,21 +716,21 @@ def impl(f, L, nside, reality, fft_type):
719
716
720
717
# Multi GPU part
721
718
722
- @staticmethod
723
- def per_shard_impl ():
724
- return NotImplemented
725
-
726
- @staticmethod
727
- def infer_sharding_from_operands (L , nside , reality , fft_type , mesh : Mesh ,
728
- arg_infos : Tuple [ShapeDtypeStruct ],
729
- result_infos : Tuple [ShapedArray ]):
730
- return NotImplemented
731
-
732
- @staticmethod
733
- def partition (L , nside , reality , fft_type , mesh : Mesh ,
734
- arg_shapes : Tuple [Tuple [int ]], result_shape : Tuple [int ]):
735
- return NotImplemented
736
-
719
+ # @staticmethod
720
+ # def per_shard_impl():
721
+ # return NotImplemented
722
+ #
723
+ # @staticmethod
724
+ # def infer_sharding_from_operands(L, nside, reality, fft_type, mesh: Mesh,
725
+ # arg_infos: Tuple[ShapeDtypeStruct],
726
+ # result_infos: Tuple[ShapedArray]):
727
+ # return NotImplemented
728
+ #
729
+ # @staticmethod
730
+ # def partition(L, nside, reality, fft_type, mesh: Mesh,
731
+ # arg_shapes: Tuple[Tuple[int]], result_shape: Tuple[int]):
732
+ # return NotImplemented
733
+ #
737
734
738
735
register_primitive (HealpixFFTPrimitive )
739
736
0 commit comments