Skip to content

Commit 1b14d85

Browse files
committed
Generate HIP global symbol stub redirection
1 parent cf491db commit 1b14d85

File tree

5 files changed

+172
-52
lines changed

5 files changed

+172
-52
lines changed

clang/lib/CIR/CodeGen/CIRGenCUDARuntime.cpp

Lines changed: 116 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,35 @@
1414

1515
#include "CIRGenCUDARuntime.h"
1616
#include "CIRGenFunction.h"
17+
#include "mlir/IR/Operation.h"
1718
#include "clang/Basic/Cuda.h"
1819
#include "clang/CIR/Dialect/IR/CIRTypes.h"
20+
#include "llvm/Support/Casting.h"
21+
#include "llvm/Support/raw_ostream.h"
22+
#include <iostream>
1923

2024
using namespace clang;
2125
using namespace clang::CIRGen;
2226

2327
CIRGenCUDARuntime::~CIRGenCUDARuntime() {}
2428

29+
CIRGenCUDARuntime::CIRGenCUDARuntime(CIRGenModule &cgm) : cgm(cgm) {
30+
if (cgm.getLangOpts().OffloadViaLLVM)
31+
llvm_unreachable("NYI");
32+
else if (cgm.getLangOpts().HIP)
33+
Prefix = "hip";
34+
else
35+
Prefix = "cuda";
36+
}
37+
38+
std::string CIRGenCUDARuntime::addPrefixToName(StringRef FuncName) const {
39+
return (Prefix + FuncName).str();
40+
}
41+
std::string
42+
CIRGenCUDARuntime::addUnderscoredPrefixToName(StringRef FuncName) const {
43+
return ("__" + Prefix + FuncName).str();
44+
}
45+
2546
void CIRGenCUDARuntime::emitDeviceStubBodyLegacy(CIRGenFunction &cgf,
2647
cir::FuncOp fn,
2748
FunctionArgList &args) {
@@ -31,16 +52,14 @@ void CIRGenCUDARuntime::emitDeviceStubBodyLegacy(CIRGenFunction &cgf,
3152
void CIRGenCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,
3253
cir::FuncOp fn,
3354
FunctionArgList &args) {
34-
if (cgm.getLangOpts().HIP)
35-
llvm_unreachable("NYI");
3655

3756
// This requires arguments to be sent to kernels in a different way.
3857
if (cgm.getLangOpts().OffloadViaLLVM)
3958
llvm_unreachable("NYI");
4059

4160
auto &builder = cgm.getBuilder();
4261

43-
// For cudaLaunchKernel, we must add another layer of indirection
62+
// For [cuda|hip]LaunchKernel, we must add another layer of indirection
4463
// to arguments. For example, for function `add(int a, float b)`,
4564
// we need to pass it as `void *args[2] = { &a, &b }`.
4665

@@ -71,7 +90,8 @@ void CIRGenCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,
7190
LangOptions::GPUDefaultStreamKind::PerThread)
7291
llvm_unreachable("NYI");
7392

74-
std::string launchAPI = "cudaLaunchKernel";
93+
std::string launchAPI = addPrefixToName("LaunchKernel");
94+
std::cout << "LaunchAPI is " << launchAPI << "\n";
7595
const IdentifierInfo &launchII = cgm.getASTContext().Idents.get(launchAPI);
7696
FunctionDecl *launchFD = nullptr;
7797
for (auto *result : dc->lookup(&launchII)) {
@@ -86,11 +106,11 @@ void CIRGenCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,
86106
}
87107

88108
// Use this function to retrieve arguments for cudaLaunchKernel:
89-
// int __cudaPopCallConfiguration(dim3 *gridDim, dim3 *blockDim, size_t
109+
// int __[cuda|hip]PopCallConfiguration(dim3 *gridDim, dim3 *blockDim, size_t
90110
// *sharedMem, cudaStream_t *stream)
91111
//
92-
// Here cudaStream_t, while also being the 6th argument of cudaLaunchKernel,
93-
// is a pointer to some opaque struct.
112+
// Here [cuda|hip]Stream_t, while also being the 6th argument of
113+
// [cuda|hip]LaunchKernel, is a pointer to some opaque struct.
94114

95115
mlir::Type dim3Ty =
96116
cgf.getTypes().convertType(launchFD->getParamDecl(1)->getType());
@@ -114,26 +134,45 @@ void CIRGenCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,
114134
cir::FuncType::get({gridDim.getType(), blockDim.getType(),
115135
sharedMem.getType(), stream.getType()},
116136
cgm.SInt32Ty),
117-
"__cudaPopCallConfiguration");
137+
addUnderscoredPrefixToName("PopCallConfiguration"));
118138
cgf.emitRuntimeCall(loc, popConfig, {gridDim, blockDim, sharedMem, stream});
119139

120140
// Now emit the call to cudaLaunchKernel
121-
// cudaError_t cudaLaunchKernel(const void *func, dim3 gridDim, dim3 blockDim,
141+
// [cuda|hip]Error_t [cuda|hip]LaunchKernel(const void *func, dim3 gridDim,
142+
// dim3 blockDim,
122143
// void **args, size_t sharedMem,
123-
// cudaStream_t stream);
124-
auto kernelTy =
125-
cir::PointerType::get(&cgm.getMLIRContext(), fn.getFunctionType());
144+
// [cuda|hip]Stream_t stream);
126145

127-
mlir::Value kernel =
128-
builder.create<cir::GetGlobalOp>(loc, kernelTy, fn.getSymName());
129-
mlir::Value func = builder.createBitcast(kernel, cgm.VoidPtrTy);
146+
// We now either pick the function or the stub global for cuda, hip
147+
// respectively.
148+
auto kernel = [&]() {
149+
if (auto globalOp = llvm::dyn_cast_or_null<cir::GlobalOp>(
150+
KernelHandles[fn.getSymName()])) {
151+
auto kernelTy =
152+
cir::PointerType::get(&cgm.getMLIRContext(), globalOp.getSymType());
153+
mlir::Value kernel = builder.create<cir::GetGlobalOp>(
154+
loc, kernelTy, globalOp.getSymName());
155+
return kernel;
156+
}
157+
if (auto funcOp = llvm::dyn_cast_or_null<cir::FuncOp>(
158+
KernelHandles[fn.getSymName()])) {
159+
auto kernelTy = cir::PointerType::get(&cgm.getMLIRContext(),
160+
funcOp.getFunctionType());
161+
mlir::Value kernel =
162+
builder.create<cir::GetGlobalOp>(loc, kernelTy, funcOp.getSymName());
163+
mlir::Value func = builder.createBitcast(kernel, cgm.VoidPtrTy);
164+
return func;
165+
}
166+
assert(false && "Expected stub handle to be cir::GlobalOp or funcOp");
167+
}();
168+
// mlir::Value func = builder.createBitcast(kernel, cgm.VoidPtrTy);
130169
CallArgList launchArgs;
131170

132171
mlir::Value kernelArgsDecayed =
133172
builder.createCast(cir::CastKind::array_to_ptrdecay, kernelArgs,
134173
cir::PointerType::get(cgm.VoidPtrTy));
135174

136-
launchArgs.add(RValue::get(func), launchFD->getParamDecl(0)->getType());
175+
launchArgs.add(RValue::get(kernel), launchFD->getParamDecl(0)->getType());
137176
launchArgs.add(
138177
RValue::getAggregate(Address(gridDim, CharUnits::fromQuantity(8))),
139178
launchFD->getParamDecl(1)->getType());
@@ -157,13 +196,16 @@ void CIRGenCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,
157196

158197
void CIRGenCUDARuntime::emitDeviceStub(CIRGenFunction &cgf, cir::FuncOp fn,
159198
FunctionArgList &args) {
160-
// Device stub and its handle might be different.
161-
if (cgm.getLangOpts().HIP)
162-
llvm_unreachable("NYI");
163-
199+
if (auto globalOp =
200+
llvm::dyn_cast<cir::GlobalOp>(KernelHandles[fn.getSymName()])) {
201+
auto symbol = mlir::FlatSymbolRefAttr::get(fn.getSymNameAttr());
202+
// Set the initializer for the global
203+
cgm.setInitializer(globalOp, symbol);
204+
}
164205
// CUDA 9.0 changed the way to launch kernels.
165206
if (CudaFeatureEnabled(cgm.getTarget().getSDKVersion(),
166207
CudaFeature::CUDA_USES_NEW_LAUNCH) ||
208+
(cgm.getLangOpts().HIP && cgm.getLangOpts().HIPUseNewLaunchAPI) ||
167209
cgm.getLangOpts().OffloadViaLLVM)
168210
emitDeviceStubBodyNew(cgf, fn, args);
169211
else
@@ -189,3 +231,57 @@ RValue CIRGenCUDARuntime::emitCUDAKernelCallExpr(CIRGenFunction &cgf,
189231

190232
return RValue::get(nullptr);
191233
}
234+
235+
mlir::Operation *CIRGenCUDARuntime::getKernelHandle(cir::FuncOp fn,
236+
GlobalDecl GD) {
237+
238+
// Check if we already have a kernel handle for this function
239+
auto Loc = KernelHandles.find(fn.getSymName());
240+
if (Loc != KernelHandles.end()) {
241+
auto OldHandle = Loc->second;
242+
// Here we know that the fn did not change. Return it
243+
if (KernelStubs[OldHandle] == fn)
244+
return OldHandle;
245+
246+
// We've found the function name, but F itself has changed, so we need to
247+
// update the references.
248+
if (cgm.getLangOpts().HIP) {
249+
// For HIP compilation the handle itself does not change, so we only need
250+
// to update the Stub value.
251+
KernelStubs[OldHandle] = fn;
252+
return OldHandle;
253+
}
254+
// For non-HIP compilation, erase the old Stub and fall-through to creating
255+
// new entries.
256+
KernelStubs.erase(OldHandle);
257+
}
258+
259+
// If not targeting HIP, store the function itself
260+
if (!cgm.getLangOpts().HIP) {
261+
KernelHandles[fn.getSymName()] = fn;
262+
KernelStubs[fn] = fn;
263+
return fn;
264+
}
265+
266+
// Create a new CIR global variable to represent the kernel handle
267+
auto &builder = cgm.getBuilder();
268+
auto globalName = cgm.getMangledName(
269+
GD.getWithKernelReferenceKind(KernelReferenceKind::Kernel));
270+
auto globalOp = cgm.getOrInsertGlobal(
271+
fn->getLoc(), globalName, fn.getFunctionType(), [&] {
272+
return CIRGenModule::createGlobalOp(
273+
cgm, fn->getLoc(), globalName,
274+
builder.getPointerTo(fn.getFunctionType()), true, /* addrSpace=*/{},
275+
/*insertPoint=*/nullptr, fn.getLinkage());
276+
});
277+
278+
globalOp->setAttr("alignment", builder.getI64IntegerAttr(
279+
cgm.getPointerAlign().getQuantity()));
280+
globalOp->setAttr("visibility", fn->getAttr("sym_visibility"));
281+
282+
// Store references
283+
KernelHandles[fn.getSymName()] = globalOp;
284+
KernelStubs[globalOp] = fn;
285+
286+
return globalOp;
287+
}

clang/lib/CIR/CodeGen/CIRGenCUDARuntime.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,26 @@ class ReturnValueSlot;
2929
class CIRGenCUDARuntime {
3030
protected:
3131
CIRGenModule &cgm;
32+
StringRef Prefix;
33+
34+
// Map a device stub function to a symbol for identifying kernel in host code.
35+
// For CUDA, the symbol for identifying the kernel is the same as the device
36+
// stub function. For HIP, they are different.
37+
llvm::DenseMap<StringRef, mlir::Operation *> KernelHandles;
38+
39+
// Map a kernel handle to the kernel stub.
40+
llvm::DenseMap<mlir::Operation *, mlir::Operation *> KernelStubs;
3241

3342
private:
3443
void emitDeviceStubBodyLegacy(CIRGenFunction &cgf, cir::FuncOp fn,
3544
FunctionArgList &args);
3645
void emitDeviceStubBodyNew(CIRGenFunction &cgf, cir::FuncOp fn,
3746
FunctionArgList &args);
47+
std::string addPrefixToName(StringRef FuncName) const;
48+
std::string addUnderscoredPrefixToName(StringRef FuncName) const;
3849

3950
public:
40-
CIRGenCUDARuntime(CIRGenModule &cgm) : cgm(cgm) {}
51+
CIRGenCUDARuntime(CIRGenModule &cgm);
4152
virtual ~CIRGenCUDARuntime();
4253

4354
virtual void emitDeviceStub(CIRGenFunction &cgf, cir::FuncOp fn,
@@ -46,6 +57,7 @@ class CIRGenCUDARuntime {
4657
virtual RValue emitCUDAKernelCallExpr(CIRGenFunction &cgf,
4758
const CUDAKernelCallExpr *expr,
4859
ReturnValueSlot retValue);
60+
virtual mlir::Operation *getKernelHandle(cir::FuncOp fn, GlobalDecl GD);
4961
};
5062

5163
} // namespace clang::CIRGen

clang/lib/CIR/CodeGen/CIRGenModule.cpp

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -651,9 +651,10 @@ void CIRGenModule::emitGlobalFunctionDefinition(GlobalDecl GD,
651651

652652
// Get or create the prototype for the function.
653653
auto Fn = dyn_cast_if_present<cir::FuncOp>(Op);
654-
if (!Fn || Fn.getFunctionType() != Ty)
655-
Fn = GetAddrOfFunction(GD, Ty, /*ForVTable=*/false, /*DontDefer=*/true,
656-
ForDefinition);
654+
if (!Fn || Fn.getFunctionType() != Ty) {
655+
Fn = GetAddrOfFunction(GD, Ty, /*ForVTable=*/false,
656+
/*DontDefer=*/true, ForDefinition);
657+
}
657658

658659
// Already emitted.
659660
if (!Fn.isDeclaration())
@@ -2356,10 +2357,17 @@ cir::FuncOp CIRGenModule::GetAddrOfFunction(clang::GlobalDecl GD, mlir::Type Ty,
23562357

23572358
// As __global__ functions (kernels) always reside on device,
23582359
// when we access them from host, we must refer to the kernel handle.
2359-
// For CUDA, it's just the device stub. For HIP, it's something different.
2360-
if (langOpts.CUDA && !langOpts.CUDAIsDevice && langOpts.HIP &&
2360+
// For HIP, we should never directly access the host device addr, but
2361+
// instead the Global Variable of that stub. For CUDA, it's just the device
2362+
// stub. For HIP, it's something different.
2363+
if ((langOpts.HIP || langOpts.CUDA) && !langOpts.CUDAIsDevice &&
23612364
cast<FunctionDecl>(GD.getDecl())->hasAttr<CUDAGlobalAttr>()) {
2362-
llvm_unreachable("NYI");
2365+
auto *stubHandle = getCUDARuntime().getKernelHandle(F, GD);
2366+
if (IsForDefinition)
2367+
return F;
2368+
2369+
if (langOpts.HIP)
2370+
llvm_unreachable("NYI");
23632371
}
23642372

23652373
return F;
@@ -3169,15 +3177,15 @@ CIRGenModule::GetAddrOfGlobal(GlobalDecl GD, ForDefinition_t IsForDefinition) {
31693177
auto FInfo =
31703178
&getTypes().arrangeCXXMethodDeclaration(cast<CXXMethodDecl>(D));
31713179
auto Ty = getTypes().GetFunctionType(*FInfo);
3172-
return GetAddrOfFunction(GD, Ty, /*ForVTable=*/false, /*DontDefer=*/false,
3173-
IsForDefinition);
3180+
return GetAddrOfFunction(GD, Ty, /*ForVTable=*/false,
3181+
/*DontDefer=*/false, IsForDefinition);
31743182
}
31753183

31763184
if (isa<FunctionDecl>(D)) {
31773185
const CIRGenFunctionInfo &FI = getTypes().arrangeGlobalDeclaration(GD);
31783186
auto Ty = getTypes().GetFunctionType(FI);
3179-
return GetAddrOfFunction(GD, Ty, /*ForVTable=*/false, /*DontDefer=*/false,
3180-
IsForDefinition);
3187+
return GetAddrOfFunction(GD, Ty, /*ForVTable=*/false,
3188+
/*DontDefer=*/false, IsForDefinition);
31813189
}
31823190

31833191
return getAddrOfGlobalVar(cast<VarDecl>(D), /*Ty=*/nullptr, IsForDefinition)

clang/test/CIR/CodeGen/HIP/simple-device.cpp

Lines changed: 0 additions & 14 deletions
This file was deleted.

clang/test/CIR/CodeGen/HIP/simple.cpp

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,34 @@
11
#include "../Inputs/cuda.h"
22

3-
// RUN: %clang_cc1 -triple=amdgcn-amd-amdhsa -x hip -fclangir \
3+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir \
4+
// RUN: -x hip -fhip-new-launch-api \
45
// RUN: -emit-cir %s -o %t.cir
5-
// RUN: FileCheck --check-prefix=CIR --input-file=%t.cir %s
6+
// RUN: FileCheck --check-prefix=CIR-HOST --input-file=%t.cir %s
67

8+
// RUN: %clang_cc1 -triple=amdgcn-amd-amdhsa -x hip \
9+
// RUN: -fcuda-is-device -fhip-new-launch-api \
10+
// RUN: -emit-cir %s -o %t.cir
11+
// RUN: FileCheck --check-prefix=CIR-DEVICE --input-file=%t.cir %s
12+
13+
// Attribute for global_fn
14+
// CIR-HOST: [[Kernel:#[a-zA-Z_0-9]+]] = {{.*}}#cir.cuda_kernel_name<_Z9global_fni>{{.*}}
715

8-
// This should emit as a normal C++ function.
9-
__host__ void host_fn(int *a, int *b, int *c) {}
1016

11-
// CIR: cir.func @_Z7host_fnPiS_S_
17+
__host__ void host_fn(int *a, int *b, int *c) {}
18+
// CIR-HOST: cir.func @_Z7host_fnPiS_S_
19+
// CIR-DEVICE-NOT: cir.func @_Z7host_fnPiS_S_
1220

13-
// This shouldn't emit.
1421
__device__ void device_fn(int* a, double b, float c) {}
22+
// CIR-HOST-NOT: cir.func @_Z9device_fnPidf
23+
// CIR-DEVICE: cir.func @_Z9device_fnPidf
24+
25+
__global__ void global_fn(int a) {}
26+
// CIR-DEVICE: @_Z9global_fni
27+
28+
// CIR-HOST: cir.alloca {{.*}}"kernel_args"
29+
// CIR-HOST: cir.call @__hipPopCallConfiguration
1530

16-
// CHECK-NOT: cir.func @_Z9device_fnPidf
31+
// Host accesses the global stub instead of the function device stub.
32+
// The stub has the mangled name of the function
33+
// CIR-HOST: cir.get_global @_Z9global_fni
34+
// CIR-HOST: cir.call @hipLaunchKernel

0 commit comments

Comments
 (0)