Skip to content

Commit 9ca00c4

Browse files
committed
[CIR][Lowering] Transform cir.store of const arrays into cir.copy
Add lowering prepare logic to lower stores of constant arrays to cir.copy. This brings LLVM lowering closer to OG, and it turns out the rest of the compiler understands memcpys better and generates better assembly code for at least arm64 and x86_64. Note that the current lowering to memcpy only uses the i32 intrinsic version; this PR does not touch that code, and that will be addressed in another PR.
1 parent 5404e9c commit 9ca00c4

File tree

8 files changed

+124
-34
lines changed

8 files changed

+124
-34
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,25 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
386386
return createAlloca(loc, addrType, type, name, alignmentIntAttr);
387387
}
388388

389+
mlir::Value createGetGlobal(mlir::cir::GlobalOp global,
390+
bool threadLocal = false) {
391+
return create<mlir::cir::GetGlobalOp>(
392+
global.getLoc(),
393+
getPointerTo(global.getSymType(), global.getAddrSpaceAttr()),
394+
global.getName(), threadLocal);
395+
}
396+
397+
/// Create a copy with inferred length.
398+
mlir::cir::CopyOp createCopy(mlir::Value dst, mlir::Value src,
399+
bool isVolatile = false) {
400+
return create<mlir::cir::CopyOp>(dst.getLoc(), dst, src, isVolatile);
401+
}
402+
403+
mlir::cir::MemCpyOp createMemCpy(mlir::Location loc, mlir::Value dst,
404+
mlir::Value src, mlir::Value len) {
405+
return create<mlir::cir::MemCpyOp>(loc, dst, src, len);
406+
}
407+
389408
mlir::Value createSub(mlir::Value lhs, mlir::Value rhs, bool hasNUW = false,
390409
bool hasNSW = false) {
391410
auto op = create<mlir::cir::BinOp>(lhs.getLoc(), lhs.getType(),

clang/lib/CIR/CodeGen/CIRGenBuilder.h

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -613,12 +613,6 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
613613
// --------------------------
614614
//
615615

616-
/// Create a copy with inferred length.
617-
mlir::cir::CopyOp createCopy(mlir::Value dst, mlir::Value src,
618-
bool isVolatile = false) {
619-
return create<mlir::cir::CopyOp>(dst.getLoc(), dst, src, isVolatile);
620-
}
621-
622616
/// Create a break operation.
623617
mlir::cir::BreakOp createBreak(mlir::Location loc) {
624618
return create<mlir::cir::BreakOp>(loc);
@@ -629,11 +623,6 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
629623
return create<mlir::cir::ContinueOp>(loc);
630624
}
631625

632-
mlir::cir::MemCpyOp createMemCpy(mlir::Location loc, mlir::Value dst,
633-
mlir::Value src, mlir::Value len) {
634-
return create<mlir::cir::MemCpyOp>(loc, dst, src, len);
635-
}
636-
637626
mlir::Value createNeg(mlir::Value value) {
638627

639628
if (auto intTy = mlir::dyn_cast<mlir::cir::IntType>(value.getType())) {
@@ -764,14 +753,6 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
764753
addrSpace);
765754
}
766755

767-
mlir::Value createGetGlobal(mlir::cir::GlobalOp global,
768-
bool threadLocal = false) {
769-
return create<mlir::cir::GetGlobalOp>(
770-
global.getLoc(),
771-
getPointerTo(global.getSymType(), global.getAddrSpaceAttr()),
772-
global.getName(), threadLocal);
773-
}
774-
775756
mlir::Value createGetBitfield(mlir::Location loc, mlir::Type resultType,
776757
mlir::Value addr, mlir::Type storageType,
777758
const CIRGenBitFieldInfo &info,

clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
8282
void lowerStdFindOp(StdFindOp op);
8383
void lowerIterBeginOp(IterBeginOp op);
8484
void lowerIterEndOp(IterEndOp op);
85+
void lowerToMemCpy(StoreOp op);
8586
void lowerArrayDtor(ArrayDtor op);
8687
void lowerArrayCtor(ArrayCtor op);
8788

@@ -112,6 +113,10 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
112113
mlir::cir::GlobalLinkageKind Linkage =
113114
mlir::cir::GlobalLinkageKind::ExternalLinkage);
114115

116+
/// Running count of global constant arrays created for stores whose
117+
/// destination has an empty name; used as a name suffix to prevent collisions.
118+
uint64_t annonGlobalConstArrayCount = 0;
119+
115120
///
116121
/// AST related
117122
/// -----------
@@ -1029,6 +1034,61 @@ void LoweringPreparePass::lowerArrayDtor(ArrayDtor op) {
10291034
lowerArrayDtorCtorIntoLoop(builder, op, eltTy, op.getAddr(), arrayLen);
10301035
}
10311036

1037+
static std::string getGlobalVarNameForConstString(mlir::cir::StoreOp op,
1038+
uint64_t &cnt) {
1039+
llvm::SmallString<64> finalName;
1040+
llvm::raw_svector_ostream Out(finalName);
1041+
1042+
Out << "__const.";
1043+
if (auto fnOp = op->getParentOfType<mlir::cir::FuncOp>()) {
1044+
Out << fnOp.getSymNameAttr().getValue() << ".";
1045+
} else {
1046+
Out << "module.";
1047+
}
1048+
1049+
auto allocaOp =
1050+
dyn_cast_or_null<mlir::cir::AllocaOp>(op.getAddr().getDefiningOp());
1051+
if (allocaOp && !allocaOp.getName().empty())
1052+
Out << allocaOp.getName();
1053+
else
1054+
Out << cnt++;
1055+
return finalName.c_str();
1056+
}
1057+
1058+
void LoweringPreparePass::lowerToMemCpy(StoreOp op) {
1059+
// Now that basic filter is done, do more checks before proceeding with the
1060+
// transformation.
1061+
auto cstOp =
1062+
dyn_cast_if_present<mlir::cir::ConstantOp>(op.getValue().getDefiningOp());
1063+
if (!cstOp)
1064+
return;
1065+
1066+
if (!isa<mlir::cir::ConstArrayAttr>(cstOp.getValue()))
1067+
return;
1068+
CIRBaseBuilderTy builder(getContext());
1069+
1070+
// Create a global which is initialized with the attribute that is either a
1071+
// constant array or struct.
1072+
assert(!::cir::MissingFeatures::unnamedAddr() && "NYI");
1073+
builder.setInsertionPointToStart(&theModule.getBodyRegion().front());
1074+
std::string globalName =
1075+
getGlobalVarNameForConstString(op, annonGlobalConstArrayCount);
1076+
mlir::cir::GlobalOp globalCst = buildRuntimeVariable(
1077+
builder, globalName, op.getLoc(), op.getValue().getType(),
1078+
mlir::cir::GlobalLinkageKind::PrivateLinkage);
1079+
globalCst.setInitialValueAttr(cstOp.getValue());
1080+
globalCst.setConstant(true);
1081+
1082+
// Transform the store into a cir.copy.
1083+
builder.setInsertionPointAfter(op.getOperation());
1084+
mlir::cir::CopyOp memCpy =
1085+
builder.createCopy(op.getAddr(), builder.createGetGlobal(globalCst));
1086+
op->replaceAllUsesWith(memCpy);
1087+
op->erase();
1088+
if (cstOp->getResult(0).getUsers().empty())
1089+
cstOp->erase();
1090+
}
1091+
10321092
void LoweringPreparePass::lowerArrayCtor(ArrayCtor op) {
10331093
CIRBaseBuilderTy builder(getContext());
10341094
builder.setInsertionPointAfter(op.getOperation());
@@ -1122,6 +1182,10 @@ void LoweringPreparePass::runOnOp(Operation *op) {
11221182
lowerArrayCtor(arrayCtor);
11231183
} else if (auto arrayDtor = dyn_cast<ArrayDtor>(op)) {
11241184
lowerArrayDtor(arrayDtor);
1185+
} else if (auto storeOp = dyn_cast<StoreOp>(op)) {
1186+
mlir::Type valTy = storeOp.getValue().getType();
1187+
if (isa<mlir::cir::ArrayType>(valTy) || isa<mlir::cir::StructType>(valTy))
1188+
lowerToMemCpy(storeOp);
11251189
} else if (auto fnOp = dyn_cast<mlir::cir::FuncOp>(op)) {
11261190
if (auto globalCtor = fnOp.getGlobalCtorAttr()) {
11271191
globalCtorList.push_back(globalCtor);
@@ -1145,7 +1209,7 @@ void LoweringPreparePass::runOnOperation() {
11451209
op->walk([&](Operation *op) {
11461210
if (isa<UnaryOp, BinOp, CastOp, ComplexBinOp, CmpThreeWayOp, VAArgOp,
11471211
GlobalOp, DynamicCastOp, StdFindOp, IterEndOp, IterBeginOp,
1148-
ArrayCtor, ArrayDtor, mlir::cir::FuncOp>(op))
1212+
ArrayCtor, ArrayDtor, mlir::cir::FuncOp, StoreOp>(op))
11491213
opsToTransform.push_back(op);
11501214
});
11511215

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1479,6 +1479,7 @@ class CIRStoreLowering : public mlir::OpConversionPattern<mlir::cir::StoreOp> {
14791479
auto ordering = getLLVMMemOrder(memorder);
14801480
auto alignOpt = op.getAlignment();
14811481
unsigned alignment = 0;
1482+
14821483
if (!alignOpt) {
14831484
const auto llvmTy =
14841485
getTypeConverter()->convertType(op.getValue().getType());

clang/test/CIR/CodeGen/array-init.c

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -emit-cir %s -o - | FileCheck %s
22

3+
// CHECK-DAG: cir.global "private" constant cir_private @__const.foo.bar = #cir.const_array<[#cir.fp<9.000000e+00> : !cir.double, #cir.fp<8.000000e+00> : !cir.double, #cir.fp<7.000000e+00> : !cir.double]> : !cir.array<!cir.double x 3>
34
typedef struct {
45
int a;
56
long b;
@@ -29,14 +30,14 @@ void buz(int x) {
2930
void foo() {
3031
double bar[] = {9,8,7};
3132
}
33+
// CHECK-LABEL: @foo
34+
// CHECK: %[[DST:.*]] = cir.alloca !cir.array<!cir.double x 3>, !cir.ptr<!cir.array<!cir.double x 3>>, ["bar"]
35+
// CHECK: %[[SRC:.*]] = cir.get_global @__const.foo.bar : !cir.ptr<!cir.array<!cir.double x 3>>
36+
// CHECK: cir.copy %[[SRC]] to %[[DST]] : !cir.ptr<!cir.array<!cir.double x 3>>
3237

33-
// CHECK: %0 = cir.alloca !cir.array<!cir.double x 3>, !cir.ptr<!cir.array<!cir.double x 3>>, ["bar"] {alignment = 16 : i64}
34-
// CHECK-NEXT: %1 = cir.const #cir.const_array<[#cir.fp<9.000000e+00> : !cir.double, #cir.fp<8.000000e+00> : !cir.double, #cir.fp<7.000000e+00> : !cir.double]> : !cir.array<!cir.double x 3>
35-
// CHECK-NEXT: cir.store %1, %0 : !cir.array<!cir.double x 3>, !cir.ptr<!cir.array<!cir.double x 3>>
3638
void bar(int a, int b, int c) {
3739
int arr[] = {a,b,c};
3840
}
39-
4041
// CHECK: cir.func @bar
4142
// CHECK: [[ARR:%.*]] = cir.alloca !cir.array<!s32i x 3>, !cir.ptr<!cir.array<!s32i x 3>>, ["arr", init] {alignment = 4 : i64}
4243
// CHECK-NEXT: cir.store %arg0, [[A:%.*]] : !s32i, !cir.ptr<!s32i>
@@ -56,7 +57,6 @@ void bar(int a, int b, int c) {
5657
void zero_init(int x) {
5758
int arr[3] = {x};
5859
}
59-
6060
// CHECK: cir.func @zero_init
6161
// CHECK: [[VAR_ALLOC:%.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["x", init] {alignment = 4 : i64}
6262
// CHECK: %1 = cir.alloca !cir.array<!s32i x 3>, !cir.ptr<!cir.array<!s32i x 3>>, ["arr", init] {alignment = 4 : i64}

clang/test/CIR/CodeGen/const-array.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ void foo() {
1212
int a[10] = {1};
1313
}
1414

15-
// CHECK: cir.func {{.*@foo}}
16-
// CHECK: %0 = cir.alloca !cir.array<!s32i x 10>, !cir.ptr<!cir.array<!s32i x 10>>, ["a"] {alignment = 16 : i64}
17-
// CHECK: %1 = cir.const #cir.const_array<[#cir.int<1> : !s32i], trailing_zeros> : !cir.array<!s32i x 10>
18-
// CHECK: cir.store %1, %0 : !cir.array<!s32i x 10>, !cir.ptr<!cir.array<!s32i x 10>>
15+
// CHECK-LABEL: @foo()
16+
// CHECK: %[[ADDR:.*]] = cir.alloca !cir.array<!s32i x 10>, !cir.ptr<!cir.array<!s32i x 10>>, ["a"]
17+
// CHECK: %[[SRC:.*]] = cir.get_global @__const.foo.a : !cir.ptr<!cir.array<!s32i x 10>>
18+
// CHECK: cir.copy %[[SRC]] to %[[ADDR]] : !cir.ptr<!cir.array<!s32i x 10>>

clang/test/CIR/Lowering/array-init.c

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t.ll
22
// RUN: FileCheck --input-file=%t.ll %s -check-prefix=LLVM
33

4-
// LLVM: charInit1.ar = internal global [4 x [4 x i8]] {{.*}}4 x i8] c"aa\00\00", [4 x i8] c"aa\00\00", [4 x i8] c"aa\00\00", [4 x i8] c"aa\00\00"], align 16
4+
// LLVM-DAG: @__const.charInit3.arr
5+
// LLVM-DAG: @__const.charInit2.arr
6+
// LLVM-DAG: @charInit1.ar = internal global [4 x [4 x i8]] {{.*}}4 x i8] c"aa\00\00", [4 x i8] c"aa\00\00", [4 x i8] c"aa\00\00", [4 x i8] c"aa\00\00"], align 16
57
char charInit1() {
68
static char ar[][4] = {"aa", "aa", "aa", "aa"};
79
return ar[0][0];
@@ -14,14 +16,16 @@ void zeroInit() {
1416
int a[3] = {0, 0, 0};
1517
}
1618

17-
// LLVM: %1 = alloca [4 x [1 x i8]], i64 1, align 1
18-
// LLVM: store [4 x [1 x i8]] {{.*}}1 x i8] c"a", [1 x i8] c"b", [1 x i8] c"c", [1 x i8] c"d"], ptr %1, align 1
19+
// LLVM: %[[PTR:.*]] = alloca [4 x [1 x i8]], i64 1, align 1
20+
// FIXME: OG uses @llvm.memcpy.p0.p0.i64
21+
// LLVM: void @llvm.memcpy.p0.p0.i32(ptr %[[PTR]], ptr @__const.charInit2.arr, i32 4, i1 false)
1922
void charInit2() {
2023
char arr[4][1] = {"a", "b", "c", "d"};
2124
}
2225

23-
// LLVM: %1 = alloca [4 x [2 x i8]], i64 1, align 1
24-
// LLVM: store [4 x [2 x i8]] {{.*}}2 x i8] c"ab", [2 x i8] c"cd", [2 x i8] c"ef", [2 x i8] c"gh"], ptr %1, align 1
26+
// LLVM: %[[PTR:.*]] = alloca [4 x [2 x i8]], i64 1, align 1
27+
// FIXME: OG uses @llvm.memcpy.p0.p0.i64
28+
// LLVM: call void @llvm.memcpy.p0.p0.i32(ptr %[[PTR]], ptr @__const.charInit3.arr, i32 8, i1 false), !dbg !16
2529
void charInit3() {
2630
char arr[4][2] = {"ab", "cd", "ef", "gh"};
2731
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// RUN: %clang_cc1 -triple aarch64-none-linux-android21 -fclangir -emit-cir -mmlir --mlir-print-ir-after=cir-lowering-prepare %s -o %t2.cir 2>&1 | FileCheck -check-prefix=AFTER %s
2+
// RUN: %clang_cc1 -triple aarch64-none-linux-android21 -fclangir -emit-llvm %s -o %t.ll
3+
// RUN: FileCheck --input-file=%t.ll %s -check-prefix=LLVM
4+
5+
void foo() {
6+
char s1[] = "Hello";
7+
}
8+
// AFTER-DAG: cir.global "private" constant cir_private @__const._Z3foov.s1 = #cir.const_array<"Hello\00" : !cir.array<!s8i x 6>> : !cir.array<!s8i x 6>
9+
// AFTER: @_Z3foov
10+
// AFTER: %[[S1:.*]] = cir.alloca !cir.array<!s8i x 6>, !cir.ptr<!cir.array<!s8i x 6>>, ["s1"]
11+
// AFTER: %[[HELLO:.*]] = cir.get_global @__const._Z3foov.s1 : !cir.ptr<!cir.array<!s8i x 6>>
12+
// AFTER: cir.copy %[[HELLO]] to %[[S1]] : !cir.ptr<!cir.array<!s8i x 6>>
13+
// AFTER: cir.return
14+
// AFTER: }
15+
16+
// LLVM: @__const._Z3foov.s1 = private constant [6 x i8] c"Hello\00"
17+
// LLVM: @_Z3foov()
18+
// LLVM: %[[S1:.*]] = alloca [6 x i8], i64 1, align 1
19+
// FIXME: LLVM OG uses @llvm.memcpy.p0.p0.i64
20+
// LLVM: call void @llvm.memcpy.p0.p0.i32(ptr %[[S1]], ptr @__const._Z3foov.s1, i32 6, i1 false)
21+
// LLVM: ret void

0 commit comments

Comments
 (0)