Skip to content

Commit 6092737

Browse files
pt: fix compile error when use intel mpi (#3919)
`MPIX_Query_cuda_support()` is not defined in intelMPI <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced conditional handling for CUDA awareness based on the MPI version to improve compatibility and performance for CUDA-enabled environments. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 57e1f4e commit 6092737

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

source/op/pt/CMakeLists.txt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,16 @@ endif()
2020

2121
find_package(MPI)
2222
if(MPI_FOUND)
23+
include(CheckCXXSymbolExists)
24+
set(CMAKE_REQUIRED_INCLUDES ${MPI_CXX_INCLUDE_DIRS})
25+
set(CMAKE_REQUIRED_LIBRARIES ${MPI_CXX_LIBRARIES})
26+
check_cxx_symbol_exists(MPIX_Query_cuda_support "mpi.h" CUDA_AWARE)
27+
if(NOT CUDA_AWARE)
28+
check_cxx_symbol_exists(MPIX_Query_cuda_support "mpi.h;mpi-ext.h" OMP_CUDA)
29+
if(NOT OMP_CUDA)
30+
target_compile_definitions(deepmd_op_pt PRIVATE NO_CUDA_AWARE)
31+
endif()
32+
endif()
2333
target_link_libraries(deepmd_op_pt PRIVATE MPI::MPI_CXX)
2434
target_compile_definitions(deepmd_op_pt PRIVATE USE_MPI)
2535
endif()

source/op/pt/comm.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,11 @@ class Border : public torch::autograd::Function<Border> {
100100
int version, subversion;
101101
MPI_Get_version(&version, &subversion);
102102
if (version >= 4) {
103+
#ifdef NO_CUDA_AWARE
104+
cuda_aware = 0;
105+
#else
103106
cuda_aware = MPIX_Query_cuda_support();
107+
#endif
104108
} else {
105109
cuda_aware = 0;
106110
}
@@ -215,7 +219,11 @@ class Border : public torch::autograd::Function<Border> {
215219
int version, subversion;
216220
MPI_Get_version(&version, &subversion);
217221
if (version >= 4) {
222+
#ifdef NO_CUDA_AWARE
223+
cuda_aware = 0;
224+
#else
218225
cuda_aware = MPIX_Query_cuda_support();
226+
#endif
219227
} else {
220228
cuda_aware = 0;
221229
}

0 commit comments

Comments
 (0)