Skip to content

Commit cc67bf7

Browse files
authored
[CIR][CUDA] Skeleton of NVPTX target lowering info (#1358)
Added a skeleton of the NVPTX target lowering info. This is enough to lower `simple.cu` (which hardly exercises any device-side functionality), so an LLVM IR test is also added to it.
1 parent a1ab6bf commit cc67bf7

File tree

5 files changed

+94
-0
lines changed

5 files changed

+94
-0
lines changed

clang/lib/CIR/Dialect/Transforms/TargetLowering/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ add_clang_library(TargetLowering
1313
TargetInfo.cpp
1414
TargetLoweringInfo.cpp
1515
Targets/AArch64.cpp
16+
Targets/NVPTX.cpp
1617
Targets/SPIR.cpp
1718
Targets/X86.cpp
1819
Targets/LoweringPrepareAArch64CXXABI.cpp

clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerModule.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ createTargetLoweringInfo(LowerModule &LM) {
8181
}
8282
case llvm::Triple::spirv64:
8383
return createSPIRVTargetLoweringInfo(LM);
84+
case llvm::Triple::nvptx64:
85+
return createNVPTXTargetLoweringInfo(LM);
8486
default:
8587
cir_cconv_unreachable("ABI NYI");
8688
}

clang/lib/CIR/Dialect/Transforms/TargetLowering/TargetInfo.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ createAArch64TargetLoweringInfo(LowerModule &CGM, cir::AArch64ABIKind AVXLevel);
3030
std::unique_ptr<TargetLoweringInfo>
3131
createSPIRVTargetLoweringInfo(LowerModule &CGM);
3232

33+
std::unique_ptr<TargetLoweringInfo>
34+
createNVPTXTargetLoweringInfo(LowerModule &CGM);
35+
3336
} // namespace cir
3437

3538
#endif // LLVM_CLANG_LIB_CIR_DIALECT_TRANSFORMS_TARGETLOWERING_TARGETINFO_H
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
//===- NVPTX.cpp - TargetInfo for NVPTX -----------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "ABIInfoImpl.h"
10+
#include "LowerFunctionInfo.h"
11+
#include "LowerTypes.h"
12+
#include "TargetInfo.h"
13+
#include "TargetLoweringInfo.h"
14+
#include "clang/CIR/ABIArgInfo.h"
15+
#include "clang/CIR/MissingFeatures.h"
16+
#include "llvm/Support/ErrorHandling.h"
17+
18+
using ABIArgInfo = cir::ABIArgInfo;
19+
using MissingFeature = cir::MissingFeatures;
20+
21+
namespace cir {
22+
23+
//===----------------------------------------------------------------------===//
24+
// NVPTX ABI Implementation
25+
//===----------------------------------------------------------------------===//
26+
27+
namespace {
28+
29+
class NVPTXABIInfo : public ABIInfo {
30+
public:
31+
NVPTXABIInfo(LowerTypes &lt) : ABIInfo(lt) {}
32+
33+
private:
34+
void computeInfo(LowerFunctionInfo &fi) const override {
35+
llvm_unreachable("NYI");
36+
}
37+
};
38+
39+
class NVPTXTargetLoweringInfo : public TargetLoweringInfo {
40+
public:
41+
NVPTXTargetLoweringInfo(LowerTypes &lt)
42+
: TargetLoweringInfo(std::make_unique<NVPTXABIInfo>(lt)) {}
43+
44+
unsigned getTargetAddrSpaceFromCIRAddrSpace(
45+
cir::AddressSpaceAttr addressSpaceAttr) const override {
46+
using Kind = cir::AddressSpaceAttr::Kind;
47+
switch (addressSpaceAttr.getValue()) {
48+
case Kind::offload_private:
49+
return 0;
50+
case Kind::offload_local:
51+
return 3;
52+
case Kind::offload_global:
53+
return 1;
54+
case Kind::offload_constant:
55+
return 2;
56+
case Kind::offload_generic:
57+
return 4;
58+
default:
59+
cir_cconv_unreachable("Unknown CIR address space for this target");
60+
}
61+
}
62+
};
63+
64+
} // namespace
65+
66+
std::unique_ptr<TargetLoweringInfo>
67+
createNVPTXTargetLoweringInfo(LowerModule &lowerModule) {
68+
return std::make_unique<NVPTXTargetLoweringInfo>(lowerModule.getTypes());
69+
}
70+
71+
} // namespace cir

clang/test/CIR/CodeGen/CUDA/simple.cu

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ __global__ void global_fn(int a) {}
3232
// CIR-HOST: cir.get_global @_Z24__device_stub__global_fni
3333
// CIR-HOST: cir.call @cudaLaunchKernel
3434

35+
// COM: LLVM-HOST: void @_Z24__device_stub__global_fni
36+
// COM: LLVM-HOST: call i32 @__cudaPopCallConfiguration
37+
// COM: LLVM-HOST: call i32 @cudaLaunchKernel(ptr @_Z24__device_stub__global_fni
38+
3539
int main() {
3640
global_fn<<<1, 1>>>(1);
3741
}
@@ -46,3 +50,16 @@ int main() {
4650
// CIR-HOST: [[Arg:%[0-9]+]] = cir.const #cir.int<1>
4751
// CIR-HOST: cir.call @_Z24__device_stub__global_fni([[Arg]])
4852
// CIR-HOST: }
53+
54+
// COM: LLVM-HOST: define dso_local i32 @main
55+
// COM: LLVM-HOST: alloca %struct.dim3
56+
// COM: LLVM-HOST: alloca %struct.dim3
57+
// COM: LLVM-HOST: call void @_ZN4dim3C1Ejjj
58+
// COM: LLVM-HOST: call void @_ZN4dim3C1Ejjj
59+
// COM: LLVM-HOST: [[LLVMConfigOK:%[0-9]+]] = call i32 @__cudaPushCallConfiguration
60+
// COM: LLVM-HOST: br [[LLVMConfigOK]], label %[[Good:[0-9]+]], label [[Bad:[0-9]+]]
61+
// COM: LLVM-HOST: [[Good]]:
62+
// COM: LLVM-HOST: call void @_Z24__device_stub__global_fni
63+
// COM: LLVM-HOST: br label [[Bad]]
64+
// COM: LLVM-HOST: [[Bad]]:
65+
// COM: LLVM-HOST: ret i32

0 commit comments

Comments
 (0)