Skip to content

Commit 4820d60

Browse files
authored
[CIR] Backport VecTernaryOp folder (#1705)
Backporting the VecTernaryOp folder
1 parent ffa62b5 commit 4820d60

File tree

5 files changed

+63
-10
lines changed

5 files changed

+63
-10
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3225,7 +3225,7 @@ def VecCmpOp : CIR_Op<"vec.cmp", [Pure, SameTypeOperands]> {
32253225
//===----------------------------------------------------------------------===//
32263226

32273227
def VecTernaryOp : CIR_Op<"vec.ternary",
3228-
[Pure, AllTypesMatch<["result", "vec1", "vec2"]>]> {
3228+
[Pure, AllTypesMatch<["result", "lhs", "rhs"]>]> {
32293229
let summary = "The `cond ? a : b` ternary operator for vector types";
32303230
let description = [{
32313231
The `cir.vec.ternary` operation represents the C/C++ ternary operator,
@@ -3244,16 +3244,18 @@ def VecTernaryOp : CIR_Op<"vec.ternary",
32443244

32453245
let arguments = (ins
32463246
CIR_VectorOfIntType:$cond,
3247-
CIR_VectorType:$vec1,
3248-
CIR_VectorType:$vec2
3247+
CIR_VectorType:$lhs,
3248+
CIR_VectorType:$rhs
32493249
);
32503250

32513251
let results = (outs CIR_VectorType:$result);
32523252
let assemblyFormat = [{
3253-
`(` $cond `,` $vec1 `,` $vec2 `)` `:` qualified(type($cond)) `,`
3254-
qualified(type($vec1)) attr-dict
3253+
`(` $cond `,` $lhs `,` $rhs `)` `:` qualified(type($cond)) `,`
3254+
qualified(type($lhs)) attr-dict
32553255
}];
3256+
32563257
let hasVerifier = 1;
3258+
let hasFolder = 1;
32573259
}
32583260

32593261
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,16 +1046,48 @@ LogicalResult cir::VecCreateOp::verify() {
10461046
// VecTernaryOp
10471047
//===----------------------------------------------------------------------===//
10481048

1049+
OpFoldResult cir::VecTernaryOp::fold(FoldAdaptor adaptor) {
1050+
mlir::Attribute cond = adaptor.getCond();
1051+
mlir::Attribute lhs = adaptor.getLhs();
1052+
mlir::Attribute rhs = adaptor.getRhs();
1053+
1054+
if (!mlir::isa_and_nonnull<cir::ConstVectorAttr>(cond) ||
1055+
!mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) ||
1056+
!mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs))
1057+
return {};
1058+
auto condVec = mlir::cast<cir::ConstVectorAttr>(cond);
1059+
auto lhsVec = mlir::cast<cir::ConstVectorAttr>(lhs);
1060+
auto rhsVec = mlir::cast<cir::ConstVectorAttr>(rhs);
1061+
1062+
mlir::ArrayAttr condElts = condVec.getElts();
1063+
1064+
SmallVector<mlir::Attribute, 16> elements;
1065+
elements.reserve(condElts.size());
1066+
1067+
for (const auto &[idx, condAttr] :
1068+
llvm::enumerate(condElts.getAsRange<cir::IntAttr>())) {
1069+
if (condAttr.getSInt()) {
1070+
elements.push_back(lhsVec.getElts()[idx]);
1071+
} else {
1072+
elements.push_back(rhsVec.getElts()[idx]);
1073+
}
1074+
}
1075+
1076+
cir::VectorType vecTy = getLhs().getType();
1077+
return cir::ConstVectorAttr::get(
1078+
vecTy, mlir::ArrayAttr::get(getContext(), elements));
1079+
}
1080+
10491081
LogicalResult cir::VecTernaryOp::verify() {
10501082
// Verify that the condition operand has the same number of elements as the
10511083
// other operands. (The automatic verification already checked that all
10521084
// operands are vector types and that the second and third operands are the
10531085
// same type.)
10541086
if (mlir::cast<cir::VectorType>(getCond().getType()).getSize() !=
1055-
getVec1().getType().getSize()) {
1087+
getLhs().getType().getSize()) {
10561088
return emitOpError() << ": the number of elements in "
1057-
<< getCond().getType() << " and "
1058-
<< getVec1().getType() << " don't match";
1089+
<< getCond().getType() << " and " << getLhs().getType()
1090+
<< " don't match";
10591091
}
10601092
return success();
10611093
}

clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ void CIRCanonicalizePass::runOnOperation() {
180180
// applyOpPatternsGreedily.
181181
if (isa<BrOp, BrCondOp, ScopeOp, SwitchOp, CastOp, TryOp, UnaryOp, SelectOp,
182182
ComplexCreateOp, ComplexRealOp, ComplexImagOp, CallOp, VecCreateOp,
183-
VecExtractOp>(op))
183+
VecExtractOp, VecTernaryOp>(op))
184184
ops.push_back(op);
185185
});
186186

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2008,7 +2008,7 @@ mlir::LogicalResult CIRToLLVMVecTernaryOpLowering::matchAndRewrite(
20082008
op.getCond().getLoc(),
20092009
typeConverter->convertType(op.getCond().getType())));
20102010
rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(
2011-
op, bitVec, adaptor.getVec1(), adaptor.getVec2());
2011+
op, bitVec, adaptor.getLhs(), adaptor.getRhs());
20122012
return mlir::success();
20132013
}
20142014

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// RUN: cir-opt %s -cir-canonicalize -o - | FileCheck %s
2+
3+
!s32i = !cir.int<s, 32>
4+
5+
module {
6+
cir.func @vector_ternary_fold_test() -> !cir.vector<!s32i x 4> {
7+
%cond = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<0> : !s32i]> : !cir.vector<!s32i x 4>
8+
%lhs = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<!s32i x 4>
9+
%rhs = cir.const #cir.const_vector<[#cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<!s32i x 4>
10+
%res = cir.vec.ternary(%cond, %lhs, %rhs) : !cir.vector<!s32i x 4>, !cir.vector<!s32i x 4>
11+
cir.return %res : !cir.vector<!s32i x 4>
12+
}
13+
14+
// [1, 0, 1, 0] ? [1, 2, 3, 4] : [5, 6, 7, 8] Will be fold to [1, 6, 3, 8]
15+
// CHECK: cir.func @vector_ternary_fold_test() -> !cir.vector<!s32i x 4> {
16+
// CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<6> : !s32i, #cir.int<3> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<!s32i x 4>
17+
// CHECK-NEXT: cir.return %[[RES]] : !cir.vector<!s32i x 4>
18+
}
19+

0 commit comments

Comments
 (0)