Skip to content

Commit 432d4f8

Browse files
committed
Generate HIP global symbol stub redirection
1 parent be82182 commit 432d4f8

File tree

5 files changed

+148
-52
lines changed

5 files changed

+148
-52
lines changed

clang/lib/CIR/CodeGen/CIRGenCUDARuntime.cpp

Lines changed: 97 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,35 @@
1414

1515
#include "CIRGenCUDARuntime.h"
1616
#include "CIRGenFunction.h"
17+
#include "mlir/IR/Operation.h"
1718
#include "clang/Basic/Cuda.h"
1819
#include "clang/CIR/Dialect/IR/CIRTypes.h"
20+
#include "llvm/Support/Casting.h"
21+
#include "llvm/Support/raw_ostream.h"
22+
#include <iostream>
1923

2024
using namespace clang;
2125
using namespace clang::CIRGen;
2226

2327
CIRGenCUDARuntime::~CIRGenCUDARuntime() {}
2428

29+
/// Bind the runtime to its module and select the runtime API name prefix
/// ("hip" when compiling HIP, "cuda" otherwise). Offloading via LLVM is not
/// yet implemented.
CIRGenCUDARuntime::CIRGenCUDARuntime(CIRGenModule &cgm) : cgm(cgm) {
  const auto &langOpts = cgm.getLangOpts();

  // No prefix has been chosen for the LLVM offloading path yet.
  if (langOpts.OffloadViaLLVM)
    llvm_unreachable("NYI");

  Prefix = langOpts.HIP ? "hip" : "cuda";
}
37+
38+
/// Return \p FuncName with the runtime prefix prepended, e.g. "LaunchKernel"
/// becomes "cudaLaunchKernel" or "hipLaunchKernel".
std::string CIRGenCUDARuntime::addPrefixToName(StringRef FuncName) const {
  std::string prefixed = Prefix.str();
  prefixed.append(FuncName.begin(), FuncName.end());
  return prefixed;
}
41+
std::string
42+
CIRGenCUDARuntime::addUnderscoredPrefixToName(StringRef FuncName) const {
43+
return ("__" + Prefix + FuncName).str();
44+
}
45+
2546
void CIRGenCUDARuntime::emitDeviceStubBodyLegacy(CIRGenFunction &cgf,
2647
cir::FuncOp fn,
2748
FunctionArgList &args) {
@@ -31,16 +52,14 @@ void CIRGenCUDARuntime::emitDeviceStubBodyLegacy(CIRGenFunction &cgf,
3152
void CIRGenCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,
3253
cir::FuncOp fn,
3354
FunctionArgList &args) {
34-
if (cgm.getLangOpts().HIP)
35-
llvm_unreachable("NYI");
3655

3756
// This requires arguments to be sent to kernels in a different way.
3857
if (cgm.getLangOpts().OffloadViaLLVM)
3958
llvm_unreachable("NYI");
4059

4160
auto &builder = cgm.getBuilder();
4261

43-
// For cudaLaunchKernel, we must add another layer of indirection
62+
// For [cuda|hip]LaunchKernel, we must add another layer of indirection
4463
// to arguments. For example, for function `add(int a, float b)`,
4564
// we need to pass it as `void *args[2] = { &a, &b }`.
4665

@@ -71,7 +90,8 @@ void CIRGenCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,
7190
LangOptions::GPUDefaultStreamKind::PerThread)
7291
llvm_unreachable("NYI");
7392

74-
std::string launchAPI = "cudaLaunchKernel";
93+
std::string launchAPI = addPrefixToName("LaunchKernel");
94+
std::cout << "LaunchAPI is " << launchAPI << "\n";
7595
const IdentifierInfo &launchII = cgm.getASTContext().Idents.get(launchAPI);
7696
FunctionDecl *launchFD = nullptr;
7797
for (auto *result : dc->lookup(&launchII)) {
@@ -86,11 +106,11 @@ void CIRGenCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,
86106
}
87107

88108
// Use this function to retrieve arguments for [cuda|hip]LaunchKernel:
89-
// int __cudaPopCallConfiguration(dim3 *gridDim, dim3 *blockDim, size_t
109+
// int __[cuda|hip]PopCallConfiguration(dim3 *gridDim, dim3 *blockDim, size_t
90110
// *sharedMem, cudaStream_t *stream)
91111
//
92-
// Here cudaStream_t, while also being the 6th argument of cudaLaunchKernel,
93-
// is a pointer to some opaque struct.
112+
// Here [cuda|hip]Stream_t, while also being the 6th argument of
113+
// [cuda|hip]LaunchKernel, is a pointer to some opaque struct.
94114

95115
mlir::Type dim3Ty =
96116
cgf.getTypes().convertType(launchFD->getParamDecl(1)->getType());
@@ -114,26 +134,44 @@ void CIRGenCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,
114134
cir::FuncType::get({gridDim.getType(), blockDim.getType(),
115135
sharedMem.getType(), stream.getType()},
116136
cgm.SInt32Ty),
117-
"__cudaPopCallConfiguration");
137+
addUnderscoredPrefixToName("PopCallConfiguration"));
118138
cgf.emitRuntimeCall(loc, popConfig, {gridDim, blockDim, sharedMem, stream});
119139

120140
// Now emit the call to [cuda|hip]LaunchKernel
121-
// cudaError_t cudaLaunchKernel(const void *func, dim3 gridDim, dim3 blockDim,
141+
// [cuda|hip]Error_t [cuda|hip]LaunchKernel(const void *func, dim3 gridDim,
142+
// dim3 blockDim,
122143
// void **args, size_t sharedMem,
123-
// cudaStream_t stream);
124-
auto kernelTy =
125-
cir::PointerType::get(&cgm.getMLIRContext(), fn.getFunctionType());
126-
127-
mlir::Value kernel =
128-
builder.create<cir::GetGlobalOp>(loc, kernelTy, fn.getSymName());
129-
mlir::Value func = builder.createBitcast(kernel, cgm.VoidPtrTy);
144+
// [cuda|hip]Stream_t stream);
145+
146+
// We now either pick the function or the stub global for cuda, hip
147+
// respectively.
148+
auto kernel = [&]() {
149+
if (auto globalOp =
150+
llvm::dyn_cast_or_null<cir::GlobalOp>(KernelHandles[fn])) {
151+
auto kernelTy =
152+
cir::PointerType::get(&cgm.getMLIRContext(), globalOp.getSymType());
153+
mlir::Value kernel = builder.create<cir::GetGlobalOp>(
154+
loc, kernelTy, globalOp.getSymName());
155+
return kernel;
156+
}
157+
if (auto funcOp = llvm::dyn_cast_or_null<cir::FuncOp>(KernelHandles[fn])) {
158+
auto kernelTy = cir::PointerType::get(&cgm.getMLIRContext(),
159+
funcOp.getFunctionType());
160+
mlir::Value kernel =
161+
builder.create<cir::GetGlobalOp>(loc, kernelTy, funcOp.getSymName());
162+
mlir::Value func = builder.createBitcast(kernel, cgm.VoidPtrTy);
163+
return func;
164+
}
165+
assert(false && "Expected stub handle to be cir::GlobalOp or funcOp");
166+
}();
167+
// mlir::Value func = builder.createBitcast(kernel, cgm.VoidPtrTy);
130168
CallArgList launchArgs;
131169

132170
mlir::Value kernelArgsDecayed =
133171
builder.createCast(cir::CastKind::array_to_ptrdecay, kernelArgs,
134172
cir::PointerType::get(cgm.VoidPtrTy));
135173

136-
launchArgs.add(RValue::get(func), launchFD->getParamDecl(0)->getType());
174+
launchArgs.add(RValue::get(kernel), launchFD->getParamDecl(0)->getType());
137175
launchArgs.add(
138176
RValue::getAggregate(Address(gridDim, CharUnits::fromQuantity(8))),
139177
launchFD->getParamDecl(1)->getType());
@@ -157,15 +195,53 @@ void CIRGenCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,
157195

158196
void CIRGenCUDARuntime::emitDeviceStub(CIRGenFunction &cgf, cir::FuncOp fn,
159197
FunctionArgList &args) {
160-
// Device stub and its handle might be different.
161-
if (cgm.getLangOpts().HIP)
162-
llvm_unreachable("NYI");
163-
198+
if (auto globalOp = llvm::dyn_cast<cir::GlobalOp>(KernelHandles[fn])) {
199+
auto symbol = mlir::FlatSymbolRefAttr::get(fn.getSymNameAttr());
200+
// Set the initializer for the global
201+
cgm.setInitializer(globalOp, symbol);
202+
}
164203
// CUDA 9.0 changed the way to launch kernels.
165204
if (CudaFeatureEnabled(cgm.getTarget().getSDKVersion(),
166205
CudaFeature::CUDA_USES_NEW_LAUNCH) ||
206+
(cgm.getLangOpts().HIP && cgm.getLangOpts().HIPUseNewLaunchAPI) ||
167207
cgm.getLangOpts().OffloadViaLLVM)
168208
emitDeviceStubBodyNew(cgf, fn, args);
169209
else
170210
emitDeviceStubBodyLegacy(cgf, fn, args);
171211
}
212+
213+
/// Return the operation that host code uses to identify the kernel whose
/// device stub is \p fn, creating and caching it on first use.
///
/// When not compiling HIP, the handle is \p fn itself. For HIP, the handle is
/// a global named after the kernel-reference mangling of \p GD, created with
/// \p fn's linkage.
mlir::Operation *CIRGenCUDARuntime::getKernelHandle(cir::FuncOp fn,
                                                    GlobalDecl GD) {
  // Check if we already have a kernel handle for this function
  auto cached = KernelHandles.find(fn);
  if (cached != KernelHandles.end())
    return cached->second;

  // If not targeting HIP, store the function itself
  if (!cgm.getLangOpts().HIP) {
    KernelHandles[fn] = fn;
    return fn;
  }

  // Create a new CIR global variable to represent the kernel handle
  auto &builder = cgm.getBuilder();
  auto globalName = cgm.getMangledName(
      GD.getWithKernelReferenceKind(KernelReferenceKind::Kernel));
  auto globalOp = cgm.getOrInsertGlobal(
      fn->getLoc(), globalName, fn.getFunctionType(), [&] {
        return CIRGenModule::createGlobalOp(
            cgm, fn->getLoc(), globalName,
            builder.getPointerTo(fn.getFunctionType()), true, /* addrSpace=*/{},
            /*insertPoint=*/nullptr, fn.getLinkage());
      });

  globalOp->setAttr("alignment", builder.getI64IntegerAttr(
                                     cgm.getPointerAlign().getQuantity()));

  // getAttr() returns a null attribute when the stub carries no
  // "sym_visibility"; only forward the attribute when it is present, since a
  // null value must not be handed to setAttr().
  if (mlir::Attribute visibility = fn->getAttr("sym_visibility"))
    globalOp->setAttr("visibility", visibility);

  // Store references
  KernelHandles[fn] = globalOp;

  return globalOp;
}

clang/lib/CIR/CodeGen/CIRGenCUDARuntime.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,19 +27,28 @@ class FunctionArgList;
2727
class CIRGenCUDARuntime {
2828
protected:
2929
CIRGenModule &cgm;
30+
StringRef Prefix;
31+
32+
// Map a device stub function to a symbol for identifying kernel in host code.
33+
// For CUDA, the symbol for identifying the kernel is the same as the device
34+
// stub function. For HIP, they are different.
35+
llvm::DenseMap<mlir::Operation *, mlir::Operation *> KernelHandles;
3036

3137
private:
3238
void emitDeviceStubBodyLegacy(CIRGenFunction &cgf, cir::FuncOp fn,
3339
FunctionArgList &args);
3440
void emitDeviceStubBodyNew(CIRGenFunction &cgf, cir::FuncOp fn,
3541
FunctionArgList &args);
42+
std::string addPrefixToName(StringRef FuncName) const;
43+
std::string addUnderscoredPrefixToName(StringRef FuncName) const;
3644

3745
public:
38-
CIRGenCUDARuntime(CIRGenModule &cgm) : cgm(cgm) {}
46+
CIRGenCUDARuntime(CIRGenModule &cgm);
3947
virtual ~CIRGenCUDARuntime();
4048

4149
virtual void emitDeviceStub(CIRGenFunction &cgf, cir::FuncOp fn,
4250
FunctionArgList &args);
51+
virtual mlir::Operation *getKernelHandle(cir::FuncOp fn, GlobalDecl GD);
4352
};
4453

4554
} // namespace clang::CIRGen

clang/lib/CIR/CodeGen/CIRGenModule.cpp

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -651,9 +651,10 @@ void CIRGenModule::emitGlobalFunctionDefinition(GlobalDecl GD,
651651

652652
// Get or create the prototype for the function.
653653
auto Fn = dyn_cast_if_present<cir::FuncOp>(Op);
654-
if (!Fn || Fn.getFunctionType() != Ty)
655-
Fn = GetAddrOfFunction(GD, Ty, /*ForVTable=*/false, /*DontDefer=*/true,
656-
ForDefinition);
654+
if (!Fn || Fn.getFunctionType() != Ty) {
655+
Fn = GetAddrOfFunction(GD, Ty, /*ForVTable=*/false,
656+
/*DontDefer=*/true, ForDefinition);
657+
}
657658

658659
// Already emitted.
659660
if (!Fn.isDeclaration())
@@ -2356,9 +2357,15 @@ cir::FuncOp CIRGenModule::GetAddrOfFunction(clang::GlobalDecl GD, mlir::Type Ty,
23562357

23572358
// As __global__ functions (kernels) always reside on device,
23582359
// when we access them from host, we must refer to the kernel handle.
2359-
// For CUDA, it's just the device stub. For HIP, it's something different.
2360-
if (langOpts.CUDA && !langOpts.CUDAIsDevice && langOpts.HIP &&
2360+
// For HIP, host code must never reference the device stub directly; it
2361+
// refers to the global variable (kernel handle) created for that stub.
2362+
// For CUDA, the kernel handle is just the device stub itself.
2363+
if ((langOpts.HIP || langOpts.CUDA) && !langOpts.CUDAIsDevice &&
23612364
cast<FunctionDecl>(GD.getDecl())->hasAttr<CUDAGlobalAttr>()) {
2365+
auto *stubHandle = getCUDARuntime().getKernelHandle(F, GD);
2366+
if (IsForDefinition)
2367+
return F;
2368+
23622369
llvm_unreachable("NYI");
23632370
}
23642371

@@ -3169,15 +3176,15 @@ CIRGenModule::GetAddrOfGlobal(GlobalDecl GD, ForDefinition_t IsForDefinition) {
31693176
auto FInfo =
31703177
&getTypes().arrangeCXXMethodDeclaration(cast<CXXMethodDecl>(D));
31713178
auto Ty = getTypes().GetFunctionType(*FInfo);
3172-
return GetAddrOfFunction(GD, Ty, /*ForVTable=*/false, /*DontDefer=*/false,
3173-
IsForDefinition);
3179+
return GetAddrOfFunction(GD, Ty, /*ForVTable=*/false,
3180+
/*DontDefer=*/false, IsForDefinition);
31743181
}
31753182

31763183
if (isa<FunctionDecl>(D)) {
31773184
const CIRGenFunctionInfo &FI = getTypes().arrangeGlobalDeclaration(GD);
31783185
auto Ty = getTypes().GetFunctionType(FI);
3179-
return GetAddrOfFunction(GD, Ty, /*ForVTable=*/false, /*DontDefer=*/false,
3180-
IsForDefinition);
3186+
return GetAddrOfFunction(GD, Ty, /*ForVTable=*/false,
3187+
/*DontDefer=*/false, IsForDefinition);
31813188
}
31823189

31833190
return getAddrOfGlobalVar(cast<VarDecl>(D), /*Ty=*/nullptr, IsForDefinition)

clang/test/CIR/CodeGen/HIP/simple-device.cpp

Lines changed: 0 additions & 14 deletions
This file was deleted.

clang/test/CIR/CodeGen/HIP/simple.cpp

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,34 @@
11
#include "../Inputs/cuda.h"
22

3-
// RUN: %clang_cc1 -triple=amdgcn-amd-amdhsa -x hip -fclangir \
3+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir \
4+
// RUN: -x hip -fhip-new-launch-api \
45
// RUN: -emit-cir %s -o %t.cir
5-
// RUN: FileCheck --check-prefix=CIR --input-file=%t.cir %s
6+
// RUN: FileCheck --check-prefix=CIR-HOST --input-file=%t.cir %s
67

8+
// RUN: %clang_cc1 -triple=amdgcn-amd-amdhsa -x hip \
9+
// RUN: -fcuda-is-device -fhip-new-launch-api \
10+
// RUN: -emit-cir %s -o %t.cir
11+
// RUN: FileCheck --check-prefix=CIR-DEVICE --input-file=%t.cir %s
12+
13+
// Attribute for global_fn
14+
// CIR-HOST: [[Kernel:#[a-zA-Z_0-9]+]] = {{.*}}#cir.cuda_kernel_name<_Z9global_fni>{{.*}}
715

8-
// This should emit as a normal C++ function.
9-
__host__ void host_fn(int *a, int *b, int *c) {}
1016

11-
// CIR: cir.func @_Z7host_fnPiS_S_
17+
__host__ void host_fn(int *a, int *b, int *c) {}
18+
// CIR-HOST: cir.func @_Z7host_fnPiS_S_
19+
// CIR-DEVICE-NOT: cir.func @_Z7host_fnPiS_S_
1220

13-
// This shouldn't emit.
1421
__device__ void device_fn(int* a, double b, float c) {}
22+
// CIR-HOST-NOT: cir.func @_Z9device_fnPidf
23+
// CIR-DEVICE: cir.func @_Z9device_fnPidf
24+
25+
__global__ void global_fn(int a) {}
26+
// CIR-DEVICE: @_Z9global_fni
27+
28+
// CIR-HOST: cir.alloca {{.*}}"kernel_args"
29+
// CIR-HOST: cir.call @__hipPopCallConfiguration
1530

16-
// CHECK-NOT: cir.func @_Z9device_fnPidf
31+
// Host accesses the global stub instead of the function's device stub.
32+
// The stub has the mangled name of the function
33+
// CIR-HOST: cir.get_global @_Z9global_fni
34+
// CIR-HOST: cir.call @hipLaunchKernel

0 commit comments

Comments
 (0)