Skip to content

Commit 0018c12

Browse files
authored
[CIR] CallConvLowering for X86 aggregate (#1387)
This handles some x86 aggregate types in the CallConvLowering pass. Suppose we have a simple struct like this:

```cpp
struct dim3 { int x, y, z; };
```

It can be coerced into:

```cpp
struct dim3_ { uint64_t xy; int z; };
```

For a function that receives it as an argument, OG (the original Clang CodeGen) performs the following transformation on x86:

```cpp
void f(dim3 arg) { /* Before */ }
void f(uint64_t xy, int z) { /* After */ }
```

This transformation is now implemented in the CallConvLowering pass of CIR.
1 parent 60126e9 commit 0018c12

File tree

4 files changed

+134
-5
lines changed

4 files changed

+134
-5
lines changed

clang/include/clang/CIR/MissingFeatures.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,7 @@ struct MissingFeatures {
347347
static bool undef() { return false; }
348348
static bool noFPClass() { return false; }
349349
static bool llvmIntrinsicElementTypeSupport() { return false; }
350+
static bool argHasMaybeUndefAttr() { return false; }
350351

351352
//-- Missing parts of the CIRGenModule::Release skeleton.
352353
static bool emitModuleInitializers() { return false; }

clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerFunction.cpp

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -612,7 +612,7 @@ llvm::LogicalResult LowerFunction::buildFunctionProlog(
612612
Ptr.getLoc(), PointerType::get(STy, ptrType.getAddrSpace()),
613613
CastKind::bitcast, Ptr);
614614
} else {
615-
cir_cconv_unreachable("NYI");
615+
addrToStoreInto = createTmpAlloca(*this, Ptr.getLoc(), STy);
616616
}
617617

618618
assert(STy.getNumElements() == NumIRArgs);
@@ -628,7 +628,7 @@ llvm::LogicalResult LowerFunction::buildFunctionProlog(
628628
}
629629

630630
if (srcSize > dstSize) {
631-
cir_cconv_unreachable("NYI");
631+
createMemCpy(*this, Ptr, addrToStoreInto, dstSize);
632632
}
633633
}
634634
} else {
@@ -1126,9 +1126,47 @@ mlir::Value LowerFunction::rewriteCallOp(const LowerFunctionInfo &CallInfo,
11261126

11271127
// Fast-isel and the optimizer generally like scalar values better than
11281128
// FCAs, so we flatten them if this is safe to do for this argument.
1129+
// As an example, if we have SrcTy = struct { i32, i32, i32 }, then the
1130+
// coerced type can be STy = struct { u64, i32 }. Hence a function with
1131+
// a single argument SrcTy will be rewritten to take two arguments,
1132+
// namely u64 and i32.
11291133
StructType STy = mlir::dyn_cast<StructType>(ArgInfo.getCoerceToType());
11301134
if (STy && ArgInfo.isDirect() && ArgInfo.getCanBeFlattened()) {
1131-
cir_cconv_unreachable("NYI");
1135+
mlir::Type SrcTy = Src.getType();
1136+
llvm::TypeSize SrcTypeSize = LM.getDataLayout().getTypeAllocSize(SrcTy);
1137+
llvm::TypeSize DstTypeSize = LM.getDataLayout().getTypeAllocSize(STy);
1138+
1139+
if (SrcTypeSize.isScalable()) {
1140+
cir_cconv_unreachable("NYI");
1141+
} else {
1142+
size_t SrcSize = SrcTypeSize.getFixedValue();
1143+
size_t DstSize = DstTypeSize.getFixedValue();
1144+
1145+
// Create a new temporary and copy Src into its leading bytes.
1146+
// Other bits will be left untouched.
1147+
// Note in OG, Src is of type Address, while here it is mlir::Value.
1148+
// Here we need to first create another alloca to convert it into a
1149+
// PointerType, so that we can call memcpy.
1150+
if (SrcSize < DstSize) {
1151+
auto Alloca = createTmpAlloca(*this, loc, STy);
1152+
auto SrcAlloca = createTmpAlloca(*this, loc, SrcTy);
1153+
rewriter.create<cir::StoreOp>(loc, Src, SrcAlloca);
1154+
createMemCpy(*this, Alloca, SrcAlloca, SrcSize);
1155+
Src = Alloca;
1156+
} else {
1157+
cir_cconv_unreachable("NYI");
1158+
}
1159+
1160+
assert(NumIRArgs == STy.getNumElements());
1161+
for (unsigned I = 0; I != STy.getNumElements(); ++I) {
1162+
mlir::Value Member = rewriter.create<cir::GetMemberOp>(
1163+
loc, PointerType::get(STy.getMembers()[I]), Src, /*name=*/"",
1164+
/*index=*/I);
1165+
mlir::Value Load = rewriter.create<cir::LoadOp>(loc, Member);
1166+
cir_cconv_assert(!cir::MissingFeatures::argHasMaybeUndefAttr());
1167+
IRCallArgs[FirstIRArg + I] = Load;
1168+
}
1169+
}
11321170
} else {
11331171
// In the simple case, just pass the coerced loaded value.
11341172
cir_cconv_assert(NumIRArgs == 1);

clang/lib/CIR/Dialect/Transforms/TargetLowering/Targets/X86.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,8 @@ void X86_64ABIInfo::classify(mlir::Type Ty, uint64_t OffsetBase, Class &Lo,
182182
return;
183183
} else if (mlir::isa<BoolType>(Ty)) {
184184
Current = Class::Integer;
185+
} else if (mlir::isa<PointerType>(Ty)) {
186+
Current = Class::Integer;
185187
} else if (const auto RT = mlir::dyn_cast<StructType>(Ty)) {
186188
uint64_t Size = getContext().getTypeSize(Ty);
187189

@@ -397,7 +399,11 @@ mlir::Type X86_64ABIInfo::GetINTEGERTypeAtOffset(mlir::Type DestTy,
397399
// returning an 8-byte unit starting with it. See if we can safely use it.
398400
if (IROffset == 0) {
399401
// Pointers and int64's always fill the 8-byte unit.
400-
cir_cconv_assert(!mlir::isa<PointerType>(DestTy) && "Ptrs are NYI");
402+
if (auto ptrTy = mlir::dyn_cast<PointerType>(DestTy)) {
403+
if (ptrTy.getTypeSizeInBits(getDataLayout().layout, {}) == 64)
404+
return DestTy;
405+
cir_cconv_unreachable("NYI");
406+
}
401407

402408
// If we have a 1/2/4-byte integer, we can use it only if the rest of the
403409
// goodness in the source type is just tail padding. This is allowed to
@@ -406,6 +412,10 @@ mlir::Type X86_64ABIInfo::GetINTEGERTypeAtOffset(mlir::Type DestTy,
406412
// have to do this analysis on the source type because we can't depend on
407413
// unions being lowered a specific way etc.
408414
if (auto intTy = mlir::dyn_cast<IntType>(DestTy)) {
415+
// Pointers and int64's always fill the 8-byte unit.
416+
if (intTy.getWidth() == 64)
417+
return DestTy;
418+
409419
if (intTy.getWidth() == 8 || intTy.getWidth() == 16 ||
410420
intTy.getWidth() == 32) {
411421
unsigned BitWidth = intTy.getWidth();

clang/test/CIR/CallConvLowering/x86_64/basic.cpp

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,4 +125,84 @@ S1 s1(S1 arg) {
125125
// CHECK: %[[#V18:]] = cir.load %[[#V17]] : !cir.ptr<!u64i>, !u64i
126126
// CHECK: cir.return %[[#V18]] : !u64i
127127
return {1, 2};
128-
}
128+
}
129+
130+
/// Test call conv lowering for flattened structs. ///
131+
132+
struct S2 {
133+
int x, y, z;
134+
};
135+
136+
// COM: Function prologue
137+
138+
// CHECK: cir.func @_Z2s22S2(%[[ARG0:[a-z0-9]+]]: !u64i {{.*}}, %[[ARG1:[a-z0-9]+]]: !s32i {{.*}}) -> !ty_anon_struct
139+
// CHECK: %[[#F0:]] = cir.alloca !ty_S2_, !cir.ptr<!ty_S2_>
140+
// CHECK: %[[#F1:]] = cir.alloca !ty_anon_struct, !cir.ptr<!ty_anon_struct>
141+
// CHECK: %[[#F2:]] = cir.get_member %[[#F1]][0]{{.*}} : !cir.ptr<!ty_anon_struct> -> !cir.ptr<!u64i>
142+
// CHECK: cir.store %[[ARG0]], %[[#F2]] : !u64i, !cir.ptr<!u64i>
143+
// CHECK: %[[#F3:]] = cir.get_member %[[#F1]][1]{{.*}} : !cir.ptr<!ty_anon_struct> -> !cir.ptr<!s32i>
144+
// CHECK: cir.store %[[ARG1]], %[[#F3]] : !s32i, !cir.ptr<!s32i>
145+
// CHECK: %[[#F4:]] = cir.cast(bitcast, %[[#F1]] : !cir.ptr<!ty_anon_struct>), !cir.ptr<!void>
146+
// CHECK: %[[#F5:]] = cir.cast(bitcast, %[[#F0]] : !cir.ptr<!ty_S2_>), !cir.ptr<!void>
147+
// CHECK: %[[#F6:]] = cir.const #cir.int<12> : !u64i
148+
// CHECK: cir.libc.memcpy %[[#F6]] bytes from %[[#F4]] to %[[#F5]]
149+
S2 s2(S2 arg) {
150+
// CHECK: %[[#F7:]] = cir.alloca !ty_S2_, !cir.ptr<!ty_S2_>, ["__retval"] {alignment = 4 : i64}
151+
// CHECK: %[[#F8:]] = cir.alloca !ty_S2_, !cir.ptr<!ty_S2_>, ["agg.tmp0"] {alignment = 4 : i64}
152+
// CHECK: %[[#F9:]] = cir.alloca !ty_S2_, !cir.ptr<!ty_S2_>, ["agg.tmp1"] {alignment = 4 : i64}
153+
// CHECK: %[[#F10:]] = cir.alloca !ty_anon_struct, !cir.ptr<!ty_anon_struct>, ["tmp"] {alignment = 8 : i64}
154+
// CHECK: %[[#F11:]] = cir.alloca !ty_S2_, !cir.ptr<!ty_S2_>, ["tmp"] {alignment = 4 : i64}
155+
// CHECK: %[[#F12:]] = cir.alloca !ty_anon_struct, !cir.ptr<!ty_anon_struct>, ["tmp"] {alignment = 8 : i64}
156+
// CHECK: %[[#F13:]] = cir.alloca !ty_anon_struct, !cir.ptr<!ty_anon_struct>, ["tmp"] {alignment = 8 : i64}
157+
158+
// COM: Construction of S2 { 1, 2, 3 }.
159+
160+
// CHECK: %[[#F14:]] = cir.get_member %[[#F8]][0] {{.*}} : !cir.ptr<!ty_S2_> -> !cir.ptr<!s32i>
161+
// CHECK: %[[#F15:]] = cir.const #cir.int<1> : !s32i
162+
// CHECK: cir.store %[[#F15]], %[[#F14]] : !s32i, !cir.ptr<!s32i>
163+
// CHECK: %[[#F16:]] = cir.get_member %[[#F8]][1] {{.*}} : !cir.ptr<!ty_S2_> -> !cir.ptr<!s32i>
164+
// CHECK: %[[#F17:]] = cir.const #cir.int<2> : !s32i
165+
// CHECK: cir.store %[[#F17]], %[[#F16]] : !s32i, !cir.ptr<!s32i>
166+
// CHECK: %[[#F18:]] = cir.get_member %[[#F8]][2] {{.*}} : !cir.ptr<!ty_S2_> -> !cir.ptr<!s32i>
167+
// CHECK: %[[#F19:]] = cir.const #cir.int<3> : !s32i
168+
// CHECK: cir.store %[[#F19]], %[[#F18]] : !s32i, !cir.ptr<!s32i>
169+
170+
// COM: Flattening of the struct.
171+
// COM: { i32, i32, i32 } -> { i64, i32 }.
172+
173+
// CHECK: %[[#F20:]] = cir.load %[[#F8]] : !cir.ptr<!ty_S2_>, !ty_S2_
174+
// CHECK: cir.store %[[#F20]], %[[#F11]] : !ty_S2_, !cir.ptr<!ty_S2_>
175+
// CHECK: %[[#F21:]] = cir.cast(bitcast, %[[#F11]] : !cir.ptr<!ty_S2_>), !cir.ptr<!void>
176+
// CHECK: %[[#F22:]] = cir.cast(bitcast, %[[#F10]] : !cir.ptr<!ty_anon_struct>), !cir.ptr<!void>
177+
// CHECK: %[[#F23:]] = cir.const #cir.int<12> : !u64i
178+
// CHECK: cir.libc.memcpy %[[#F23]] bytes from %[[#F21]] to %[[#F22]]
179+
180+
// COM: Function call.
181+
// COM: Retrieve the two values in { i64, i32 }.
182+
183+
// CHECK: %[[#F24:]] = cir.get_member %[[#F10]][0] {name = ""} : !cir.ptr<!ty_anon_struct> -> !cir.ptr<!u64i>
184+
// CHECK: %[[#F25:]] = cir.load %[[#F24]] : !cir.ptr<!u64i>, !u64i
185+
// CHECK: %[[#F26:]] = cir.get_member %[[#F10]][1] {name = ""} : !cir.ptr<!ty_anon_struct> -> !cir.ptr<!s32i>
186+
// CHECK: %[[#F27:]] = cir.load %[[#F26]] : !cir.ptr<!s32i>, !s32i
187+
// CHECK: %[[#F28:]] = cir.call @_Z2s22S2(%[[#F25]], %[[#F27]]) : (!u64i, !s32i) -> !ty_anon_struct
188+
// CHECK: cir.store %[[#F28]], %[[#F12]] : !ty_anon_struct, !cir.ptr<!ty_anon_struct>
189+
190+
// CHECK: %[[#F29:]] = cir.cast(bitcast, %[[#F12]] : !cir.ptr<!ty_anon_struct>), !cir.ptr<!void>
191+
// CHECK: %[[#F30:]] = cir.cast(bitcast, %[[#F9]] : !cir.ptr<!ty_S2_>), !cir.ptr<!void>
192+
// CHECK: %[[#F31:]] = cir.const #cir.int<12> : !u64i
193+
// CHECK: cir.libc.memcpy %[[#F31]] bytes from %[[#F29]] to %[[#F30]]
194+
195+
// COM: Construct S2 { 1, 2, 3 } again.
196+
// COM: It has been tested above, so no duplication here.
197+
198+
// COM: For return, the first two fields of S2 are also coerced.
199+
200+
// CHECK: %[[#F39:]] = cir.cast(bitcast, %[[#F7]] : !cir.ptr<!ty_S2_>), !cir.ptr<!void>
201+
// CHECK: %[[#F40:]] = cir.cast(bitcast, %[[#F13]] : !cir.ptr<!ty_anon_struct>), !cir.ptr<!void>
202+
// CHECK: %[[#F41:]] = cir.const #cir.int<12> : !u64i
203+
// cir.libc.memcpy %[[#F41]] bytes from %[[#F39]] to %[[#F40]]
204+
// CHECK: %[[#F42:]] = cir.load %[[#F13]] : !cir.ptr<!ty_anon_struct>, !ty_anon_struct
205+
// cir.return %[[#F42]] : !ty_anon_struct
206+
s2({ 1, 2, 3 });
207+
return { 1, 2, 3 };
208+
}

0 commit comments

Comments
 (0)