
Commit dd9353b

[CIR][ThroughMLIR] Lower CIR IV load with SCF IV move operation (#729)
Previously, when lowering the induction variable (IV) in a forOp, we removed the IV load and replaced its users with the SCF induction variable (scf.IV). However, the users of the CIR IV may still be CIR operations while the forOp is being lowered, so a CIR operation could end up with scf.IV as an operand, which has an MLIR builtin integer type instead of a CIR type. This commit instead lowers the CIR load of IV_ADDR to "arith.addi scf.IV, 0", so scf.IV can be propagated by the OpAdaptor when the individual IV users are lowered. This simplifies the lowering and fixes the issue. The redundant arith.addi can be removed by later MLIR passes.
1 parent c849210 commit dd9353b
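
To illustrate the issue, here is a minimal, hand-written sketch (the operations, SSA names, and types are illustrative only, not taken from this commit). With the previous approach, the cir.load of the IV was erased and its users were rewired directly to the scf.for induction variable, so a user that was still a CIR operation could be left in a state like this:

    scf.for %iv = %lb to %ub step %c1 : i32 {
      // %iv is a builtin i32 value, but this still-unlowered CIR user expects
      // a CIR integer (!s32i): the operand type mismatch described above.
      cir.store %iv, %elem_addr : !s32i, !cir.ptr<!s32i>
      ...
    }

With the change in this commit, the load is instead replaced by an arith.addi of the induction variable and 0, and the conversion framework's OpAdaptor hands that value to each IV user as the user itself is lowered.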

File tree

3 files changed: 111 additions and 257 deletions

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp

Lines changed: 8 additions & 20 deletions
@@ -204,19 +204,6 @@ void SCFLoop::analysis() {
   assert(upperBound && "can't find loop upper bound");
 }
 
-// Return true if op operation is in the loop body.
-static bool isInLoopBody(mlir::Operation *op) {
-  mlir::Operation *parentOp = op->getParentOp();
-  if (!parentOp)
-    return false;
-  if (isa<mlir::scf::ForOp>(parentOp))
-    return true;
-  auto forOp = dyn_cast<mlir::cir::ForOp>(parentOp);
-  if (forOp && (&forOp.getBody() == op->getParentRegion()))
-    return true;
-  return false;
-}
-
 void SCFLoop::transferToSCFForOp() {
   auto ub = getUpperBound();
   auto lb = getLowerBound();
@@ -236,12 +223,13 @@ void SCFLoop::transferToSCFForOp() {
         "Not support lowering loop with break, continue or if yet");
     // Replace the IV usage to scf loop induction variable.
     if (isIVLoad(op, IVAddr)) {
-      auto newIV = scfForOp.getInductionVar();
-      op->getResult(0).replaceAllUsesWith(newIV);
-      // Only erase the IV load in the loop body because all the operations
-      // in loop step and condition regions will be erased.
-      if (isInLoopBody(op))
-        rewriter->eraseOp(op);
+      // Replace CIR IV load with arith.addi scf.IV, 0.
+      // The replacement makes the SCF IV can be automatically propogated
+      // by OpAdaptor for individual IV user lowering.
+      // The redundant arith.addi can be removed by later MLIR passes.
+      rewriter->setInsertionPoint(op);
+      auto newIV = plusConstant(scfForOp.getInductionVar(), loc, 0);
+      rewriter->replaceOp(op, newIV.getDefiningOp());
     }
     return mlir::WalkResult::advance();
   });
@@ -318,4 +306,4 @@ void populateCIRLoopToSCFConversionPatterns(mlir::RewritePatternSet &patterns,
       converter, patterns.getContext());
 }
 
-} // namespace cir
+} // namespace cir
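
In IR terms, the rewrite in transferToSCFForOp() above now turns an IV load in the loop body into an add of zero at the same position, instead of erasing it. A rough before/after sketch (SSA names and the !s32i element type are illustrative, not taken from this commit):

    // Before: the CIR load of the induction variable's address.
    %v = cir.load %iv_addr : !cir.ptr<!s32i>, !s32i

    // After: the load is replaced by an addition of the scf.for induction
    // variable and 0, presumably what the plusConstant helper in the diff
    // builds through the rewriter.
    %zero = arith.constant 0 : i32
    %v = arith.addi %iv, %zero : i32

When a user of %v is later lowered, its OpAdaptor already provides the converted i32 value, which is what makes the per-user lowering straightforward.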

clang/test/CIR/Lowering/ThroughMLIR/for.cir

Lines changed: 0 additions & 237 deletions
This file was deleted.
Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir %s -o %t.mlir
+// RUN: FileCheck --input-file=%t.mlir %s
+
+int a[101], b[101];
+
+void constantLoopBound() {
+  for (int i = 0; i < 100; ++i)
+    a[i] = 3;
+}
+// CHECK-LABEL: func.func @_Z17constantLoopBoundv() {
+// CHECK: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK: %[[C100:.*]] = arith.constant 100 : i32
+// CHECK: %[[C1:.*]] = arith.constant 1 : i32
+// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C100]] step %[[C1]] : i32 {
+// CHECK: %[[C3:.*]] = arith.constant 3 : i32
+// CHECK: %[[BASE:.*]] = memref.get_global @a : memref<101xi32>
+// CHECK: %[[C0_i32:.*]] = arith.constant 0 : i32
+// CHECK: %[[IV:.*]] = arith.addi %[[I]], %[[C0_i32]] : i32
+// CHECK: %[[INDEX:.*]] = arith.index_cast %[[IV]] : i32 to index
+// CHECK: memref.store %[[C3]], %[[BASE]][%[[INDEX]]] : memref<101xi32>
+// CHECK: }
+
+void constantLoopBound_LE() {
+  for (int i = 0; i <= 100; ++i)
+    a[i] = 3;
+}
+// CHECK-LABEL: func.func @_Z20constantLoopBound_LEv() {
+// CHECK: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK: %[[C100:.*]] = arith.constant 100 : i32
+// CHECK: %[[C1:.*]] = arith.constant 1 : i32
+// CHECK: %[[C101:.*]] = arith.addi %c100_i32, %c1_i32 : i32
+// CHECK: %[[C1_STEP:.*]] = arith.constant 1 : i32
+// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C101]] step %[[C1_STEP]] : i32 {
+// CHECK: %[[C3:.*]] = arith.constant 3 : i32
+// CHECK: %[[BASE:.*]] = memref.get_global @a : memref<101xi32>
+// CHECK: %[[C0_i32:.*]] = arith.constant 0 : i32
+// CHECK: %[[IV:.*]] = arith.addi %[[I]], %[[C0_i32]] : i32
+// CHECK: %[[INDEX:.*]] = arith.index_cast %[[IV]] : i32 to index
+// CHECK: memref.store %[[C3]], %[[BASE]][%[[INDEX]]] : memref<101xi32>
+// CHECK: }
+
+void variableLoopBound(int l, int u) {
+  for (int i = l; i < u; ++i)
+    a[i] = 3;
+}
+// CHECK-LABEL: func.func @_Z17variableLoopBoundii
+// CHECK: memref.store %arg0, %alloca[] : memref<i32>
+// CHECK: memref.store %arg1, %alloca_0[] : memref<i32>
+// CHECK: %[[LOWER:.*]] = memref.load %alloca[] : memref<i32>
+// CHECK: %[[UPPER:.*]] = memref.load %alloca_0[] : memref<i32>
+// CHECK: %[[C1:.*]] = arith.constant 1 : i32
+// CHECK: scf.for %[[I:.*]] = %[[LOWER]] to %[[UPPER]] step %[[C1]] : i32 {
+// CHECK: %[[C3:.*]] = arith.constant 3 : i32
+// CHECK: %[[BASE:.*]] = memref.get_global @a : memref<101xi32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK: %[[IV:.*]] = arith.addi %[[I]], %[[C0]] : i32
+// CHECK: %[[INDEX:.*]] = arith.index_cast %[[IV]] : i32 to index
+// CHECK: memref.store %[[C3]], %[[BASE]][%[[INDEX]]] : memref<101xi32>
+// CHECK: }
+
+void ariableLoopBound_LE(int l, int u) {
+  for (int i = l; i <= u; i+=4)
+    a[i] = 3;
+}
+// CHECK-LABEL: func.func @_Z19ariableLoopBound_LEii
+// CHECK: memref.store %arg0, %alloca[] : memref<i32>
+// CHECK: memref.store %arg1, %alloca_0[] : memref<i32>
+// CHECK: %[[LOWER:.*]] = memref.load %alloca[] : memref<i32>
+// CHECK: %[[UPPER_DEC_1:.*]] = memref.load %alloca_0[] : memref<i32>
+// CHECK: %[[C1:.*]] = arith.constant 1 : i32
+// CHECK: %[[UPPER:.*]] = arith.addi %[[UPPER_DEC_1]], %[[C1]] : i32
+// CHECK: %[[C4:.*]] = arith.constant 4 : i32
+// CHECK: scf.for %[[I:.*]] = %[[LOWER]] to %[[UPPER]] step %[[C4]] : i32 {
+// CHECK: %[[C3:.*]] = arith.constant 3 : i32
+// CHECK: %[[BASE:.*]] = memref.get_global @a : memref<101xi32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK: %[[IV:.*]] = arith.addi %[[I]], %[[C0]] : i32
+// CHECK: %[[INDEX:.*]] = arith.index_cast %[[IV]] : i32 to index
+// CHECK: memref.store %[[C3]], %[[BASE]][%[[INDEX]]] : memref<101xi32>
+// CHECK: }
+
+void incArray() {
+  for (int i = 0; i < 100; ++i)
+    a[i] += b[i];
+}
+// CHECK-LABEL: func.func @_Z8incArrayv() {
+// CHECK: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK: %[[C100:.*]] = arith.constant 100 : i32
+// CHECK: %[[C1:.*]] = arith.constant 1 : i32
+// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C100]] step %[[C1]] : i32 {
+// CHECK: %[[B:.*]] = memref.get_global @b : memref<101xi32>
+// CHECK: %[[C0_2:.*]] = arith.constant 0 : i32
+// CHECK: %[[IV2:.*]] = arith.addi %[[I]], %[[C0_2]] : i32
+// CHECK: %[[INDEX_2:.*]] = arith.index_cast %[[IV2]] : i32 to index
+// CHECK: %[[B_VALUE:.*]] = memref.load %[[B]][%[[INDEX_2]]] : memref<101xi32>
+// CHECK: %[[A:.*]] = memref.get_global @a : memref<101xi32>
+// CHECK: %[[C0_1:.*]] = arith.constant 0 : i32
+// CHECK: %[[IV1:.*]] = arith.addi %[[I]], %[[C0_1]] : i32
+// CHECK: %[[INDEX_1:.*]] = arith.index_cast %[[IV1]] : i32 to index
+// CHECK: %[[A_VALUE:.*]] = memref.load %[[A]][%[[INDEX_1]]] : memref<101xi32>
+// CHECK: %[[SUM:.*]] = arith.addi %[[A_VALUE]], %[[B_VALUE]] : i32
+// CHECK: memref.store %[[SUM]], %[[A]][%[[INDEX_1]]] : memref<101xi32>
+// CHECK: }
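
The arith.addi of the loop induction variable and the constant 0 that these CHECK lines match is deliberately redundant. As the commit message notes, later MLIR passes can remove it; for example, the standard canonicalizer folds an add with a zero constant. A sketch of what that cleanup would look like (not something this test runs):

    // Before canonicalization, as matched above:
    %c0 = arith.constant 0 : i32
    %iv2 = arith.addi %i, %c0 : i32
    %idx = arith.index_cast %iv2 : i32 to index

    // After running, e.g., mlir-opt --canonicalize, the addi folds to its
    // non-zero operand and the dead constant is removed:
    %idx = arith.index_cast %i : i32 to index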
