14
14
15
15
#include " CIRGenCUDARuntime.h"
16
16
#include " CIRGenFunction.h"
17
+ #include " mlir/IR/Operation.h"
17
18
#include " clang/Basic/Cuda.h"
18
19
#include " clang/CIR/Dialect/IR/CIRTypes.h"
20
+ #include " llvm/Support/Casting.h"
21
+ #include " llvm/Support/raw_ostream.h"
22
+ #include < iostream>
19
23
20
24
using namespace clang ;
21
25
using namespace clang ::CIRGen;
22
26
23
27
// Out-of-line definition; nothing to clean up.
CIRGenCUDARuntime::~CIRGenCUDARuntime() = default;
24
28
29
+ CIRGenCUDARuntime::CIRGenCUDARuntime (CIRGenModule &cgm) : cgm(cgm) {
30
+ if (cgm.getLangOpts ().OffloadViaLLVM )
31
+ llvm_unreachable (" NYI" );
32
+ else if (cgm.getLangOpts ().HIP )
33
+ Prefix = " hip" ;
34
+ else
35
+ Prefix = " cuda" ;
36
+ }
37
+
38
+ std::string CIRGenCUDARuntime::addPrefixToName (StringRef FuncName) const {
39
+ return (Prefix + FuncName).str ();
40
+ }
41
+ std::string
42
+ CIRGenCUDARuntime::addUnderscoredPrefixToName (StringRef FuncName) const {
43
+ return (" __" + Prefix + FuncName).str ();
44
+ }
45
+
25
46
void CIRGenCUDARuntime::emitDeviceStubBodyLegacy (CIRGenFunction &cgf,
26
47
cir::FuncOp fn,
27
48
FunctionArgList &args) {
@@ -31,16 +52,14 @@ void CIRGenCUDARuntime::emitDeviceStubBodyLegacy(CIRGenFunction &cgf,
31
52
void CIRGenCUDARuntime::emitDeviceStubBodyNew (CIRGenFunction &cgf,
32
53
cir::FuncOp fn,
33
54
FunctionArgList &args) {
34
- if (cgm.getLangOpts ().HIP )
35
- llvm_unreachable (" NYI" );
36
55
37
56
// This requires arguments to be sent to kernels in a different way.
38
57
if (cgm.getLangOpts ().OffloadViaLLVM )
39
58
llvm_unreachable (" NYI" );
40
59
41
60
auto &builder = cgm.getBuilder ();
42
61
43
- // For cudaLaunchKernel , we must add another layer of indirection
62
+ // For [cuda|hip]LaunchKernel , we must add another layer of indirection
44
63
// to arguments. For example, for function `add(int a, float b)`,
45
64
// we need to pass it as `void *args[2] = { &a, &b }`.
46
65
@@ -71,7 +90,8 @@ void CIRGenCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,
71
90
LangOptions::GPUDefaultStreamKind::PerThread)
72
91
llvm_unreachable (" NYI" );
73
92
74
- std::string launchAPI = " cudaLaunchKernel" ;
93
+ std::string launchAPI = addPrefixToName (" LaunchKernel" );
94
+ std::cout << " LaunchAPI is " << launchAPI << " \n " ;
75
95
const IdentifierInfo &launchII = cgm.getASTContext ().Idents .get (launchAPI);
76
96
FunctionDecl *launchFD = nullptr ;
77
97
for (auto *result : dc->lookup (&launchII)) {
@@ -86,11 +106,11 @@ void CIRGenCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,
86
106
}
87
107
88
108
// Use this function to retrieve arguments for cudaLaunchKernel:
89
- // int __cudaPopCallConfiguration (dim3 *gridDim, dim3 *blockDim, size_t
109
+ // int __[cuda|hip]PopCallConfiguration (dim3 *gridDim, dim3 *blockDim, size_t
90
110
// *sharedMem, cudaStream_t *stream)
91
111
//
92
- // Here cudaStream_t , while also being the 6th argument of cudaLaunchKernel,
93
- // is a pointer to some opaque struct.
112
+ // Here [cuda|hip]Stream_t , while also being the 6th argument of
113
+ // [cuda|hip]LaunchKernel, is a pointer to some opaque struct.
94
114
95
115
mlir::Type dim3Ty =
96
116
cgf.getTypes ().convertType (launchFD->getParamDecl (1 )->getType ());
@@ -114,26 +134,45 @@ void CIRGenCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,
114
134
cir::FuncType::get ({gridDim.getType (), blockDim.getType (),
115
135
sharedMem.getType (), stream.getType ()},
116
136
cgm.SInt32Ty ),
117
- " __cudaPopCallConfiguration " );
137
+ addUnderscoredPrefixToName ( " PopCallConfiguration " ) );
118
138
cgf.emitRuntimeCall (loc, popConfig, {gridDim, blockDim, sharedMem, stream});
119
139
120
140
// Now emit the call to cudaLaunchKernel
121
- // cudaError_t cudaLaunchKernel(const void *func, dim3 gridDim, dim3 blockDim,
141
+ // [cuda|hip]Error_t [cuda|hip]LaunchKernel(const void *func, dim3 gridDim,
142
+ // dim3 blockDim,
122
143
// void **args, size_t sharedMem,
123
- // cudaStream_t stream);
124
- auto kernelTy =
125
- cir::PointerType::get (&cgm.getMLIRContext (), fn.getFunctionType ());
144
+ // [cuda|hip]Stream_t stream);
126
145
127
- mlir::Value kernel =
128
- builder.create <cir::GetGlobalOp>(loc, kernelTy, fn.getSymName ());
129
- mlir::Value func = builder.createBitcast (kernel, cgm.VoidPtrTy );
146
+ // We now either pick the function or the stub global for cuda, hip
147
+ // respectively.
148
+ auto kernel = [&]() {
149
+ if (auto globalOp = llvm::dyn_cast_or_null<cir::GlobalOp>(
150
+ KernelHandles[fn.getSymName ()])) {
151
+ auto kernelTy =
152
+ cir::PointerType::get (&cgm.getMLIRContext (), globalOp.getSymType ());
153
+ mlir::Value kernel = builder.create <cir::GetGlobalOp>(
154
+ loc, kernelTy, globalOp.getSymName ());
155
+ return kernel;
156
+ }
157
+ if (auto funcOp = llvm::dyn_cast_or_null<cir::FuncOp>(
158
+ KernelHandles[fn.getSymName ()])) {
159
+ auto kernelTy = cir::PointerType::get (&cgm.getMLIRContext (),
160
+ funcOp.getFunctionType ());
161
+ mlir::Value kernel =
162
+ builder.create <cir::GetGlobalOp>(loc, kernelTy, funcOp.getSymName ());
163
+ mlir::Value func = builder.createBitcast (kernel, cgm.VoidPtrTy );
164
+ return func;
165
+ }
166
+ assert (false && " Expected stub handle to be cir::GlobalOp or funcOp" );
167
+ }();
168
+ // mlir::Value func = builder.createBitcast(kernel, cgm.VoidPtrTy);
130
169
CallArgList launchArgs;
131
170
132
171
mlir::Value kernelArgsDecayed =
133
172
builder.createCast (cir::CastKind::array_to_ptrdecay, kernelArgs,
134
173
cir::PointerType::get (cgm.VoidPtrTy ));
135
174
136
- launchArgs.add (RValue::get (func ), launchFD->getParamDecl (0 )->getType ());
175
+ launchArgs.add (RValue::get (kernel ), launchFD->getParamDecl (0 )->getType ());
137
176
launchArgs.add (
138
177
RValue::getAggregate (Address (gridDim, CharUnits::fromQuantity (8 ))),
139
178
launchFD->getParamDecl (1 )->getType ());
@@ -157,13 +196,16 @@ void CIRGenCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,
157
196
158
197
void CIRGenCUDARuntime::emitDeviceStub (CIRGenFunction &cgf, cir::FuncOp fn,
159
198
FunctionArgList &args) {
160
- // Device stub and its handle might be different.
161
- if (cgm.getLangOpts ().HIP )
162
- llvm_unreachable (" NYI" );
163
-
199
+ if (auto globalOp =
200
+ llvm::dyn_cast<cir::GlobalOp>(KernelHandles[fn.getSymName ()])) {
201
+ auto symbol = mlir::FlatSymbolRefAttr::get (fn.getSymNameAttr ());
202
+ // Set the initializer for the global
203
+ cgm.setInitializer (globalOp, symbol);
204
+ }
164
205
// CUDA 9.0 changed the way to launch kernels.
165
206
if (CudaFeatureEnabled (cgm.getTarget ().getSDKVersion (),
166
207
CudaFeature::CUDA_USES_NEW_LAUNCH) ||
208
+ (cgm.getLangOpts ().HIP && cgm.getLangOpts ().HIPUseNewLaunchAPI ) ||
167
209
cgm.getLangOpts ().OffloadViaLLVM )
168
210
emitDeviceStubBodyNew (cgf, fn, args);
169
211
else
@@ -189,3 +231,57 @@ RValue CIRGenCUDARuntime::emitCUDAKernelCallExpr(CIRGenFunction &cgf,
189
231
190
232
return RValue::get (nullptr );
191
233
}
234
+
235
+ mlir::Operation *CIRGenCUDARuntime::getKernelHandle (cir::FuncOp fn,
236
+ GlobalDecl GD) {
237
+
238
+ // Check if we already have a kernel handle for this function
239
+ auto Loc = KernelHandles.find (fn.getSymName ());
240
+ if (Loc != KernelHandles.end ()) {
241
+ auto OldHandle = Loc->second ;
242
+ // Here we know that the fn did not change. Return it
243
+ if (KernelStubs[OldHandle] == fn)
244
+ return OldHandle;
245
+
246
+ // We've found the function name, but F itself has changed, so we need to
247
+ // update the references.
248
+ if (cgm.getLangOpts ().HIP ) {
249
+ // For HIP compilation the handle itself does not change, so we only need
250
+ // to update the Stub value.
251
+ KernelStubs[OldHandle] = fn;
252
+ return OldHandle;
253
+ }
254
+ // For non-HIP compilation, erase the old Stub and fall-through to creating
255
+ // new entries.
256
+ KernelStubs.erase (OldHandle);
257
+ }
258
+
259
+ // If not targeting HIP, store the function itself
260
+ if (!cgm.getLangOpts ().HIP ) {
261
+ KernelHandles[fn.getSymName ()] = fn;
262
+ KernelStubs[fn] = fn;
263
+ return fn;
264
+ }
265
+
266
+ // Create a new CIR global variable to represent the kernel handle
267
+ auto &builder = cgm.getBuilder ();
268
+ auto globalName = cgm.getMangledName (
269
+ GD.getWithKernelReferenceKind (KernelReferenceKind::Kernel));
270
+ auto globalOp = cgm.getOrInsertGlobal (
271
+ fn->getLoc (), globalName, fn.getFunctionType (), [&] {
272
+ return CIRGenModule::createGlobalOp (
273
+ cgm, fn->getLoc (), globalName,
274
+ builder.getPointerTo (fn.getFunctionType ()), true , /* addrSpace=*/ {},
275
+ /* insertPoint=*/ nullptr , fn.getLinkage ());
276
+ });
277
+
278
+ globalOp->setAttr (" alignment" , builder.getI64IntegerAttr (
279
+ cgm.getPointerAlign ().getQuantity ()));
280
+ globalOp->setAttr (" visibility" , fn->getAttr (" sym_visibility" ));
281
+
282
+ // Store references
283
+ KernelHandles[fn.getSymName ()] = globalOp;
284
+ KernelStubs[globalOp] = fn;
285
+
286
+ return globalOp;
287
+ }
0 commit comments