106 changes: 106 additions & 0 deletions CMakeLists.txt
@@ -0,0 +1,106 @@
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
project(cuhpx LANGUAGES CXX CUDA)

# ---------- Basics ----------
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
# LTO/IPO can drop CUDA fatbins; keep it off for these extensions
set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF)

# CUDA archs default; override via -DCMAKE_CUDA_ARCHITECTURES="80;86;90" (quote the semicolon list)
if(NOT CMAKE_CUDA_ARCHITECTURES OR CMAKE_CUDA_ARCHITECTURES STREQUAL "OFF")
set(CMAKE_CUDA_ARCHITECTURES 90) # Hopper default
endif()
message(STATUS "✅ cuhpx build configured. CUDA archs: ${CMAKE_CUDA_ARCHITECTURES}")

# ---------- Python / Torch ----------
find_package(Python REQUIRED COMPONENTS Interpreter Development.Module)

# Find site-packages so we can locate Torch's CMake configs
execute_process(
COMMAND "${Python_EXECUTABLE}" -c "import sysconfig; print(sysconfig.get_paths()['purelib'])"
OUTPUT_VARIABLE PYTHON_SITE_PACKAGES
OUTPUT_STRIP_TRAILING_WHITESPACE
)
list(APPEND CMAKE_PREFIX_PATH "${PYTHON_SITE_PACKAGES}/torch/share/cmake")

find_package(Torch REQUIRED) # defines ${TORCH_LIBRARIES}, ${TORCH_CXX_FLAGS}
find_package(pybind11 REQUIRED)

# ---------- CUDA toolchain (modern imported targets) ----------
find_package(CUDAToolkit REQUIRED) # provides CUDA::cudart, etc.

# Optional: Torch's Python shim (resolves at::Tensor pybind casters)
find_library(TORCH_PYTHON_LIBRARY
NAMES torch_python
HINTS "${PYTHON_SITE_PACKAGES}/torch/lib" "${TORCH_INSTALL_PREFIX}/lib"
)
if(TORCH_PYTHON_LIBRARY)
message(STATUS "Found torch_python at: ${TORCH_PYTHON_LIBRARY}")
endif()

include_directories(
${TORCH_INCLUDE_DIRS}
${CMAKE_CURRENT_SOURCE_DIR}/src
)

# ---------- Helper: apply safe CUDA flags per target ----------
function(cuhpx_apply_cuda_flags target_name)
target_compile_options(${target_name} PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>
$<$<COMPILE_LANGUAGE:CUDA>:--expt-extended-lambda>
$<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe>
$<$<COMPILE_LANGUAGE:CUDA>:--diag_suppress=20014>
)
set_target_properties(${target_name} PROPERTIES
CUDA_SEPARABLE_COMPILATION ON
INTERPROCEDURAL_OPTIMIZATION FALSE
)
endfunction()

# ==================== cuhpx_fft ====================
set(CUHPX_FFT_SRC
src/harmonic_transform/hpx_fft.cpp
src/harmonic_transform/hpx_fft_cuda.cu
)
pybind11_add_module(cuhpx_fft MODULE ${CUHPX_FFT_SRC})
target_compile_definitions(cuhpx_fft PRIVATE ${TORCH_CXX_FLAGS})
target_link_libraries(cuhpx_fft PRIVATE
${TORCH_LIBRARIES}
CUDA::cudart
Python::Module
)
if(TORCH_PYTHON_LIBRARY)
target_link_libraries(cuhpx_fft PRIVATE ${TORCH_PYTHON_LIBRARY})
endif()
set_target_properties(cuhpx_fft PROPERTIES PREFIX "" OUTPUT_NAME "cuhpx_fft")
cuhpx_apply_cuda_flags(cuhpx_fft)

# ==================== cuhpx_remap ====================
set(CUHPX_REMAP_SRC
src/data_remapping/hpx_remapping.cpp
src/data_remapping/hpx_remapping_cuda.cu
)
pybind11_add_module(cuhpx_remap MODULE ${CUHPX_REMAP_SRC})
target_compile_definitions(cuhpx_remap PRIVATE ${TORCH_CXX_FLAGS})
target_link_libraries(cuhpx_remap PRIVATE
${TORCH_LIBRARIES}
CUDA::cudart
Python::Module
)
if(TORCH_PYTHON_LIBRARY)
target_link_libraries(cuhpx_remap PRIVATE ${TORCH_PYTHON_LIBRARY})
endif()
set_target_properties(cuhpx_remap PROPERTIES PREFIX "" OUTPUT_NAME "cuhpx_remap")
cuhpx_apply_cuda_flags(cuhpx_remap)

# ---------- Install layout for wheels ----------
# scikit-build-core will install these into site-packages/cuhpx/ for wheels.
install(TARGETS cuhpx_fft cuhpx_remap LIBRARY DESTINATION cuhpx)
install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/cuhpx/data
DESTINATION cuhpx
FILES_MATCHING PATTERN "*.fits"
)

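Since both extensions now install into the cuhpx package rather than as top-level modules, a quick post-install smoke test looks like this (a sketch, not part of the PR; the function names are taken from the PYBIND11_MODULE bindings further down):

# Sketch: confirm the two extensions built by this CMakeLists are importable
# from the cuhpx package after `pip install .`.
from cuhpx import cuhpx_fft, cuhpx_remap

print(cuhpx_fft.__file__)    # expected: .../site-packages/cuhpx/cuhpx_fft*.so
print(cuhpx_remap.__file__)  # expected: .../site-packages/cuhpx/cuhpx_remap*.so
assert hasattr(cuhpx_fft, "healpix_rfft")
assert hasattr(cuhpx_remap, "ring2nest")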
2 changes: 1 addition & 1 deletion cuhpx/hpx_remap.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-import cuhpx_remap
+from . import cuhpx_remap
import torch


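The relative import is the required counterpart of the new install layout (a minimal illustration, assuming the wheel built by this PR):

# Old layout (setuptools): site-packages/cuhpx_remap*.so, so a top-level
# `import cuhpx_remap` resolved. New layout (scikit-build-core):
# site-packages/cuhpx/cuhpx_remap*.so, so only the package-relative form works.
from cuhpx import cuhpx_remap  # what `from . import cuhpx_remap` resolves to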
29 changes: 1 addition & 28 deletions cuhpx/hpx_sht.py
@@ -13,10 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-import cuhpx_fft
+from . import cuhpx_fft
import numpy as np
import torch
-import torch.cuda.nvtx
import torch.nn as nn
from torch.autograd import Function

@@ -449,13 +448,10 @@ def einsum_with_chunking(x, weights, mmax, xout, nchunk, stream1):
device = torch.device("cuda")
chunk_size = int(weights.size(1) / nchunk + 1) # Adjust this based on your memory constraints

torch.cuda.nvtx.range_push("Allocate memory for chunk")
next_chunk_cpu = torch.empty((weights.size(0), chunk_size, weights.size(2)), dtype=weights.dtype, pin_memory=True)
current_chunk = torch.empty((weights.size(0), chunk_size, weights.size(2)), dtype=weights.dtype, device=device)
next_chunk = torch.empty_like(current_chunk)
-torch.cuda.nvtx.range_pop()

torch.cuda.nvtx.range_push("einsum between x and weights with chunking")

# Create events for synchronization
event_transfer = torch.cuda.Event(blocking=True)
@@ -475,39 +471,30 @@
if actual_chunk_size != chunk_size:
next_chunk_cpu.resize_((weights.size(0), actual_chunk_size, weights.size(2)))

torch.cuda.nvtx.range_push("CPU copy from weights to pin memory")
next_chunk_cpu.copy_(weights[:, start_i:end_i, :])
-torch.cuda.nvtx.range_pop()

with torch.cuda.stream(stream1):
torch.cuda.nvtx.range_push(f"Transfer weights chunk {i}:{end_i} to GPU")
next_chunk[: weights.size(0), : end_i - start_i, :].copy_(next_chunk_cpu, non_blocking=True)
event_transfer.record(stream1)
-torch.cuda.nvtx.range_pop()

torch.cuda.nvtx.range_push(f"Compute einsum for chunk {i - chunk_size}:{end_i - chunk_size}")
xout[..., start_j:end_j, :, :] = torch.einsum(
'...kmn,mlk->...lmn', x, current_chunk[:, : end_j - start_j, :].to(x.dtype)
)

event_computation.record(torch.cuda.current_stream())
-torch.cuda.nvtx.range_pop()
torch.cuda.current_stream().wait_event(event_transfer)

current_chunk, next_chunk = next_chunk, current_chunk
start_j, end_j = start_i, end_i

if start_i < weights.size(1):
torch.cuda.nvtx.range_push("Compute einsum for the last chunk")
xout[..., start_i:end_i, :, :] = torch.einsum(
'...kmn,mlk->...lmn', x, current_chunk[:, : end_i - start_i, :].to(x.dtype)
)
-torch.cuda.nvtx.range_pop()

stream1.synchronize()
torch.cuda.current_stream().synchronize()

-torch.cuda.nvtx.range_pop() # End of einsum with chunking

return xout

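The deleted torch.cuda.nvtx calls were profiling markers only; the transfer/compute overlap itself is unchanged. Distilled to its core, the pattern is (a simplified sketch with a plain matmul standing in for the einsum; names and shapes are illustrative, not the shipped code):

import torch

def overlapped_matmul(x, w_cpu, nchunk=4):
    """Double-buffering sketch: stage chunk i+1 host->device on a side stream
    while the default stream computes with chunk i."""
    side = torch.cuda.Stream()
    moved = torch.cuda.Event()   # chunk i+1 finished copying to the GPU
    freed = torch.cuda.Event()   # compute reading `cur` has been enqueued
    moved.record()               # record once so the first waits are no-ops
    freed.record()
    rows = w_cpu.size(0)
    step = (rows + nchunk - 1) // nchunk
    pinned = torch.empty(step, w_cpu.size(1), dtype=w_cpu.dtype, pin_memory=True)
    cur = w_cpu[:step].cuda()    # first chunk, synchronous copy
    nxt = torch.empty_like(cur)
    outs = []
    for start in range(0, rows, step):
        nxt_start = start + step
        if nxt_start < rows:
            n = min(step, rows - nxt_start)
            moved.synchronize()  # pinned buffer is safe to overwrite again
            pinned[:n].copy_(w_cpu[nxt_start:nxt_start + n])    # CPU -> pinned
            with torch.cuda.stream(side):
                side.wait_event(freed)   # target buffer no longer being read
                nxt[:n].copy_(pinned[:n], non_blocking=True)    # pinned -> GPU
                moved.record(side)
        m = min(step, rows - start)
        outs.append(x @ cur[:m].T)       # compute overlaps the copy above
        freed.record()
        if nxt_start < rows:
            torch.cuda.current_stream().wait_event(moved)       # chunk ready
            cur, nxt = nxt, cur
    return torch.cat(outs, dim=-1)

The payoff is memory, not speed: only two device-side chunks of step rows live at once instead of the full weights tensor, while the side stream hides most of the host-to-device transfer behind the einsum.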
@@ -523,39 +510,29 @@ def forward(ctx, x, weights, pct, W, mmax, lmax, nside):
ctx.lmax = lmax
ctx.nside = nside

torch.cuda.nvtx.range_push("rfft")
# SHT
if x.dim() == 1:
x = cuhpx_fft.healpix_rfft_class(x, mmax, nside)
else:
x = cuhpx_fft.healpix_rfft_batch(x, mmax, nside)

-torch.cuda.nvtx.range_pop()

x = torch.view_as_real(x)

out_shape = list(x.size())
out_shape[-3] = lmax
out_shape[-2] = mmax

torch.cuda.nvtx.range_push("allocate xout")
xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device)
-torch.cuda.nvtx.range_pop()

torch.cuda.nvtx.range_push("einsum between pct and weights")
weights = pct * weights
-torch.cuda.nvtx.range_pop()

if not pct.is_cuda:
torch.cuda.nvtx.range_push("einsum between x and weights using two stream")
nchunk = 12
stream1 = torch.cuda.Stream()
xout = einsum_with_chunking(x, weights, mmax, xout, nchunk, stream1)
-torch.cuda.nvtx.range_pop()
else:
torch.cuda.nvtx.range_push("einsum between x and weights")
xout = torch.einsum('...kmn,mlk->...lmn', x, weights.to(x.dtype))
-torch.cuda.nvtx.range_pop()

x = torch.view_as_complex(xout.contiguous())

@@ -595,18 +572,14 @@ def forward(ctx, x, weights, pct, W, mmax, lmax, nside):

x = torch.view_as_real(x)

torch.cuda.nvtx.range_push("einsum between x and pct")
xs = torch.einsum('...lmn, mlk->...kmn', x, pct.to(x.dtype))
-torch.cuda.nvtx.range_pop()

x = torch.view_as_complex(xs.contiguous())

torch.cuda.nvtx.range_push("irfft")
if x.dim() == 2:
x = cuhpx_fft.healpix_irfft_class(x, mmax, nside)
else:
x = cuhpx_fft.healpix_irfft_batch(x, mmax, nside)
-torch.cuda.nvtx.range_pop()

return x

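Both transform paths round-trip through torch.view_as_real and torch.view_as_complex around each einsum; the invariant they rely on is (a standalone sketch):

import torch

x = torch.randn(4, 5, dtype=torch.complex64)
r = torch.view_as_real(x)        # (..., 2) float32 view, no copy
assert torch.equal(torch.view_as_complex(r.contiguous()), x)
# view_as_complex requires the last dimension to be size 2 and contiguous,
# which is why the code above calls .contiguous() after each einsum.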
38 changes: 26 additions & 12 deletions pyproject.toml
@@ -1,21 +1,23 @@
[build-system]
requires = ["torch", "setuptools"]
build-backend = "setuptools.build_meta"
requires = ["scikit-build-core>=0.8", "pybind11>=2.12", "numpy"]
build-backend = "scikit_build_core.build"

[project]
name = "cuHPX"
version = "2025.5.1"
description = "GPU-accelerated utilities for data on HEALPix grids."
readme = "README.md"
license = { file="LICENSE.txt" }
version = "2025.8.1"
description = "CUDA-accelerated HEALPix tools for harmonic transforms and remapping"
authors = [
{ name = "NVIDIA", email = "asubramaniam@nvidia.com" }
{ name = "NVIDIA", email = "asubramaniam@nvidia.com" }
]
readme = "README.md"
license = { file = "LICENSE.txt" }
requires-python = ">=3.8"
dependencies = [
"numpy",
"torch>=2.0.0",
"astropy",
"torch_harmonics",
"numpy",
"astropy",
"torch_harmonics"
# Expect PyTorch preinstalled in the environment; if you want to enforce:
# "torch>=2.4",
]
classifiers = [
"Development Status :: 2 - Pre-Alpha",
@@ -32,6 +34,18 @@ classifiers = [
[project.urls]
"Homepage" = "https://github.com/NVlabs/cuHPX"

+[tool.scikit-build]
+wheel.packages = ["cuhpx"]
+cmake.minimum-version = "3.18"
+cmake.source-dir = "."
+# Keep build artifacts in build/, not alongside sources
+build-dir = "build/{wheel_tag}"
+sdist.include = ["src", "cuhpx", "tests", "CMakeLists.txt", "LICENSE.txt", "README.md"]
+
+[tool.scikit-build.editable]
+# Key: makes editable installs load extensions from build/ via a redirect hook
+mode = "redirect"
+verbose = true

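With redirect mode, an editable install keeps the compiled extensions under build/ and installs an import hook that resolves them from there, while pure-Python modules load straight from the checkout. One way to sanity-check this (a sketch; exact paths depend on the scikit-build-core version):

# After e.g. `pip install -e . --no-build-isolation`:
import cuhpx.hpx_remap as hr
print(hr.__file__)               # expected: <repo>/cuhpx/hpx_remap.py
print(hr.cuhpx_remap.__file__)   # expected: under build/<wheel_tag>/, via the hook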
[tool.black]
line-length = 120
@@ -115,4 +129,4 @@ target-version = 'py38'
"S101", # asserts allowed in tests...
"ARG", # Unused function args -> fixtures nevertheless are functionally relevant...
"FBT", # Don't care about booleans as positional arguments in tests, e.g. via @pytest.mark.parametrize()
-]
+]
55 changes: 0 additions & 55 deletions setup.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/data_remapping/hpx_remapping.cpp
@@ -449,7 +449,7 @@ torch::Tensor xy2xy_batch(torch::Tensor data_xy_in, const std::string& s_origin,



-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+PYBIND11_MODULE(cuhpx_remap, m) {

m.def("ring2nest", &ring2nest, "Convert ring to nest (CUDA)");
m.def("nest2ring", &nest2ring, "Convert nest to ring (CUDA)");
2 changes: 1 addition & 1 deletion src/harmonic_transform/hpx_fft.cpp
@@ -671,7 +671,7 @@ torch::Tensor healpix_irfft(torch::Tensor ftm, int L, int nside) {
return f;
}

-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+PYBIND11_MODULE(cuhpx_fft, m) {
m.def("healpix_rfft", &healpix_rfft, "HEALPix RFFT");
m.def("healpix_irfft", &healpix_irfft, "HEALPix IRFFT");
m.def("healpix_rfft_cufft", &healpix_rfft_cufft, "HEALPix RFFT with cuFFT");
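Both renames matter because TORCH_EXTENSION_NAME is a macro that torch's JIT extension builder injects via -D; under plain CMake it is undefined, and the literal module name must match the shared object's OUTPUT_NAME, or the import fails with "ImportError: dynamic module does not define module export function (PyInit_<name>)". A quick consistency check (sketch):

# Sketch: the PYBIND11_MODULE name must equal the extension's file stem
# (OUTPUT_NAME in CMakeLists.txt); otherwise CPython can't find PyInit_<name>.
import importlib, pathlib

for name in ("cuhpx_fft", "cuhpx_remap"):
    mod = importlib.import_module(f"cuhpx.{name}")
    assert pathlib.Path(mod.__file__).name.startswith(name), mod.__file__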