Skip to content

Commit c76c138

Browse files
authored
[CIR][CIRGen][Builtin][Neon] Lower neon_vshll_n (#1010)
1 parent d7de21f commit c76c138

File tree

2 files changed

+113
-53
lines changed

2 files changed

+113
-53
lines changed

clang/lib/CIR/CodeGen/CIRGenBuiltinAArch64.cpp

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2241,6 +2241,19 @@ static mlir::Value buildNeonShiftVector(CIRGenBuilderTy &builder,
22412241
return builder.create<mlir::cir::ConstantOp>(loc, vecTy, constVecAttr);
22422242
}
22432243

2244+
/// Build ShiftOp of vector type whose shift amount is a vector built
2245+
/// from a constant integer using `buildNeonShiftVector` function
2246+
static mlir::Value buildCommonNeonShift(CIRGenBuilderTy &builder,
2247+
mlir::Location loc,
2248+
mlir::cir::VectorType resTy,
2249+
mlir::Value shifTgt,
2250+
mlir::Value shiftAmt, bool shiftLeft,
2251+
bool negAmt = false) {
2252+
shiftAmt = buildNeonShiftVector(builder, shiftAmt, resTy, loc, negAmt);
2253+
return builder.create<mlir::cir::ShiftOp>(
2254+
loc, resTy, builder.createBitcast(shifTgt, resTy), shiftAmt, shiftLeft);
2255+
}
2256+
22442257
mlir::Value CIRGenFunction::buildCommonNeonBuiltinExpr(
22452258
unsigned builtinID, unsigned llvmIntrinsic, unsigned altLLVMIntrinsic,
22462259
const char *nameHint, unsigned modifier, const CallExpr *e,
@@ -2328,9 +2341,18 @@ mlir::Value CIRGenFunction::buildCommonNeonBuiltinExpr(
23282341
case NEON::BI__builtin_neon_vshl_n_v:
23292342
case NEON::BI__builtin_neon_vshlq_n_v: {
23302343
mlir::Location loc = getLoc(e->getExprLoc());
2331-
ops[1] = buildNeonShiftVector(builder, ops[1], vTy, loc, false);
2332-
return builder.create<mlir::cir::ShiftOp>(
2333-
loc, vTy, builder.createBitcast(ops[0], vTy), ops[1], true);
2344+
return buildCommonNeonShift(builder, loc, vTy, ops[0], ops[1], true);
2345+
}
2346+
case NEON::BI__builtin_neon_vshll_n_v: {
2347+
mlir::Location loc = getLoc(e->getExprLoc());
2348+
mlir::cir::VectorType srcTy =
2349+
builder.getExtendedOrTruncatedElementVectorType(
2350+
vTy, false /* truncate */,
2351+
mlir::cast<mlir::cir::IntType>(vTy.getEltType()).isSigned());
2352+
ops[0] = builder.createBitcast(ops[0], srcTy);
2353+
// The following cast will be lowered to SExt or ZExt in LLVM.
2354+
ops[0] = builder.createIntCast(ops[0], vTy);
2355+
return buildCommonNeonShift(builder, loc, vTy, ops[0], ops[1], true);
23342356
}
23352357
}
23362358

clang/test/CIR/CodeGen/AArch64/neon.c

Lines changed: 88 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -6587,61 +6587,99 @@ uint32x2_t test_vqrshrun_n_s64(int64x2_t a) {
65876587
// return vqrshrn_high_n_u64(a, b, 19);
65886588
// }
65896589

6590-
// NYI-LABEL: @test_vshll_n_s8(
6591-
// NYI: [[TMP0:%.*]] = sext <8 x i8> %a to <8 x i16>
6592-
// NYI: [[VSHLL_N:%.*]] = shl <8 x i16> [[TMP0]], <i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3>
6593-
// NYI: ret <8 x i16> [[VSHLL_N]]
6594-
// int16x8_t test_vshll_n_s8(int8x8_t a) {
6595-
// return vshll_n_s8(a, 3);
6596-
// }
6590+
// Exercises vshll_n_s8 with a shift of 3. The checks below expect the signed
// 8-bit lanes to be widened to 16 bits (an integral cast in CIR, a sext in the
// lowered IR) and then shifted left by a constant vector of 3s.
int16x8_t test_vshll_n_s8(int8x8_t a) {
6591+
return vshll_n_s8(a, 3);
6592+
6593+
// CIR-LABEL: vshll_n_s8
6594+
// CIR: [[SHIFT_TGT:%.*]] = cir.cast(integral, {{%.*}} : !cir.vector<!s8i x 8>), !cir.vector<!s16i x 8>
6595+
// CIR: [[SHIFT_AMT:%.*]] = cir.const #cir.const_vector<[#cir.int<3> : !s16i, #cir.int<3> : !s16i, #cir.int<3> : !s16i, #cir.int<3> : !s16i,
6596+
// CIR-SAME: #cir.int<3> : !s16i, #cir.int<3> : !s16i, #cir.int<3> : !s16i, #cir.int<3> : !s16i]> : !cir.vector<!s16i x 8>
6597+
// CIR: {{%.*}} = cir.shift(left, [[SHIFT_TGT]] : !cir.vector<!s16i x 8>, [[SHIFT_AMT]] : !cir.vector<!s16i x 8>) -> !cir.vector<!s16i x 8>
6598+
6599+
// LLVM: {{.*}}@test_vshll_n_s8(<8 x i8>{{.*}}[[A:%.*]])
6600+
// LLVM: [[TMP0:%.*]] = sext <8 x i8> [[A]] to <8 x i16>
6601+
// LLVM: [[VSHLL_N:%.*]] = shl <8 x i16> [[TMP0]], <i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3>
6602+
// LLVM: ret <8 x i16> [[VSHLL_N]]
6603+
}
65976604

6598-
// NYI-LABEL: @test_vshll_n_s16(
6599-
// NYI: [[TMP0:%.*]] = bitcast <4 x i16> %a to <8 x i8>
6600-
// NYI: [[TMP1:%.*]] = bitcast <8 x i8> [[TMP0]] to <4 x i16>
6601-
// NYI: [[TMP2:%.*]] = sext <4 x i16> [[TMP1]] to <4 x i32>
6602-
// NYI: [[VSHLL_N:%.*]] = shl <4 x i32> [[TMP2]], <i32 9, i32 9, i32 9, i32 9>
6603-
// NYI: ret <4 x i32> [[VSHLL_N]]
6604-
// int32x4_t test_vshll_n_s16(int16x4_t a) {
6605-
// return vshll_n_s16(a, 9);
6606-
// }
6605+
// Exercises vshll_n_s16 with a shift of 9. The checks below expect the signed
// 16-bit lanes to be widened to 32 bits (an integral cast in CIR; in the
// lowered IR a bitcast round-trip through <8 x i8> followed by a sext) and
// then shifted left by a constant vector of 9s.
int32x4_t test_vshll_n_s16(int16x4_t a) {
6606+
return vshll_n_s16(a, 9);
66076607

6608-
// NYI-LABEL: @test_vshll_n_s32(
6609-
// NYI: [[TMP0:%.*]] = bitcast <2 x i32> %a to <8 x i8>
6610-
// NYI: [[TMP1:%.*]] = bitcast <8 x i8> [[TMP0]] to <2 x i32>
6611-
// NYI: [[TMP2:%.*]] = sext <2 x i32> [[TMP1]] to <2 x i64>
6612-
// NYI: [[VSHLL_N:%.*]] = shl <2 x i64> [[TMP2]], <i64 19, i64 19>
6613-
// NYI: ret <2 x i64> [[VSHLL_N]]
6614-
// int64x2_t test_vshll_n_s32(int32x2_t a) {
6615-
// return vshll_n_s32(a, 19);
6616-
// }
6608+
// CIR-LABEL: vshll_n_s16
6609+
// CIR: [[SHIFT_TGT:%.*]] = cir.cast(integral, {{%.*}} : !cir.vector<!s16i x 4>), !cir.vector<!s32i x 4>
6610+
// CIR: [[SHIFT_AMT:%.*]] = cir.const #cir.const_vector<[#cir.int<9> : !s32i, #cir.int<9> : !s32i, #cir.int<9> :
6611+
// CIR-SAME: !s32i, #cir.int<9> : !s32i]> : !cir.vector<!s32i x 4>
6612+
// CIR: {{%.*}} = cir.shift(left, [[SHIFT_TGT]] : !cir.vector<!s32i x 4>, [[SHIFT_AMT]] : !cir.vector<!s32i x 4>) -> !cir.vector<!s32i x 4>
66176613

6618-
// NYI-LABEL: @test_vshll_n_u8(
6619-
// NYI: [[TMP0:%.*]] = zext <8 x i8> %a to <8 x i16>
6620-
// NYI: [[VSHLL_N:%.*]] = shl <8 x i16> [[TMP0]], <i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3>
6621-
// NYI: ret <8 x i16> [[VSHLL_N]]
6622-
// uint16x8_t test_vshll_n_u8(uint8x8_t a) {
6623-
// return vshll_n_u8(a, 3);
6624-
// }
6614+
// LLVM: {{.*}}@test_vshll_n_s16(<4 x i16>{{.*}}[[A:%.*]])
6615+
// LLVM: [[TMP0:%.*]] = bitcast <4 x i16> [[A]] to <8 x i8>
6616+
// LLVM: [[TMP1:%.*]] = bitcast <8 x i8> [[TMP0]] to <4 x i16>
6617+
// LLVM: [[TMP2:%.*]] = sext <4 x i16> [[TMP1]] to <4 x i32>
6618+
// LLVM: [[VSHLL_N:%.*]] = shl <4 x i32> [[TMP2]], <i32 9, i32 9, i32 9, i32 9>
6619+
// LLVM: ret <4 x i32> [[VSHLL_N]]
6620+
}
66256621

6626-
// NYI-LABEL: @test_vshll_n_u16(
6627-
// NYI: [[TMP0:%.*]] = bitcast <4 x i16> %a to <8 x i8>
6628-
// NYI: [[TMP1:%.*]] = bitcast <8 x i8> [[TMP0]] to <4 x i16>
6629-
// NYI: [[TMP2:%.*]] = zext <4 x i16> [[TMP1]] to <4 x i32>
6630-
// NYI: [[VSHLL_N:%.*]] = shl <4 x i32> [[TMP2]], <i32 9, i32 9, i32 9, i32 9>
6631-
// NYI: ret <4 x i32> [[VSHLL_N]]
6632-
// uint32x4_t test_vshll_n_u16(uint16x4_t a) {
6633-
// return vshll_n_u16(a, 9);
6634-
// }
6622+
// Exercises vshll_n_s32 with a shift of 19. The checks below expect the signed
// 32-bit lanes to be widened to 64 bits (an integral cast in CIR; in the
// lowered IR a bitcast round-trip through <8 x i8> followed by a sext) and
// then shifted left by a constant vector of 19s.
int64x2_t test_vshll_n_s32(int32x2_t a) {
6623+
return vshll_n_s32(a, 19);
66356624

6636-
// NYI-LABEL: @test_vshll_n_u32(
6637-
// NYI: [[TMP0:%.*]] = bitcast <2 x i32> %a to <8 x i8>
6638-
// NYI: [[TMP1:%.*]] = bitcast <8 x i8> [[TMP0]] to <2 x i32>
6639-
// NYI: [[TMP2:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i64>
6640-
// NYI: [[VSHLL_N:%.*]] = shl <2 x i64> [[TMP2]], <i64 19, i64 19>
6641-
// NYI: ret <2 x i64> [[VSHLL_N]]
6642-
// uint64x2_t test_vshll_n_u32(uint32x2_t a) {
6643-
// return vshll_n_u32(a, 19);
6644-
// }
6625+
// CIR-LABEL: vshll_n_s32
6626+
// CIR: [[SHIFT_TGT:%.*]] = cir.cast(integral, {{%.*}} : !cir.vector<!s32i x 2>), !cir.vector<!s64i x 2>
6627+
// CIR: [[SHIFT_AMT:%.*]] = cir.const #cir.const_vector<[#cir.int<19> : !s64i, #cir.int<19> : !s64i]> : !cir.vector<!s64i x 2>
6628+
// CIR: {{%.*}} = cir.shift(left, [[SHIFT_TGT]] : !cir.vector<!s64i x 2>, [[SHIFT_AMT]] : !cir.vector<!s64i x 2>)
6629+
6630+
// LLVM: {{.*}}@test_vshll_n_s32(<2 x i32>{{.*}}[[A:%.*]])
6631+
// LLVM: [[TMP0:%.*]] = bitcast <2 x i32> [[A]] to <8 x i8>
6632+
// LLVM: [[TMP1:%.*]] = bitcast <8 x i8> [[TMP0]] to <2 x i32>
6633+
// LLVM: [[TMP2:%.*]] = sext <2 x i32> [[TMP1]] to <2 x i64>
6634+
// LLVM: [[VSHLL_N:%.*]] = shl <2 x i64> [[TMP2]], <i64 19, i64 19>
6635+
// LLVM: ret <2 x i64> [[VSHLL_N]]
6636+
}
6637+
6638+
// Exercises vshll_n_u8 with a shift of 3. The checks below expect the unsigned
// 8-bit lanes to be widened to 16 bits (an integral cast in CIR, a zext in the
// lowered IR) and then shifted left by a constant vector of 3s. The final
// check pins the returned value to the shift result, matching the sibling
// vshll_n tests.
uint16x8_t test_vshll_n_u8(uint8x8_t a) {
  return vshll_n_u8(a, 3);

  // CIR-LABEL: vshll_n_u8
  // CIR: [[SHIFT_TGT:%.*]] = cir.cast(integral, {{%.*}} : !cir.vector<!u8i x 8>), !cir.vector<!u16i x 8>
  // CIR: [[SHIFT_AMT:%.*]] = cir.const #cir.const_vector<[#cir.int<3> : !u16i, #cir.int<3> : !u16i, #cir.int<3> : !u16i, #cir.int<3> : !u16i,
  // CIR-SAME: #cir.int<3> : !u16i, #cir.int<3> : !u16i, #cir.int<3> : !u16i, #cir.int<3> : !u16i]> : !cir.vector<!u16i x 8>
  // CIR: {{%.*}} = cir.shift(left, [[SHIFT_TGT]] : !cir.vector<!u16i x 8>, [[SHIFT_AMT]] : !cir.vector<!u16i x 8>)

  // LLVM: {{.*}}@test_vshll_n_u8(<8 x i8>{{.*}}[[A:%.*]])
  // LLVM: [[TMP0:%.*]] = zext <8 x i8> [[A]] to <8 x i16>
  // LLVM: [[VSHLL_N:%.*]] = shl <8 x i16> [[TMP0]], <i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3>
  // LLVM: ret <8 x i16> [[VSHLL_N]]
}
6651+
6652+
// Exercises vshll_n_u16 with a shift of 9. The checks below expect the
// unsigned 16-bit lanes to be widened to 32 bits (an integral cast in CIR; in
// the lowered IR a bitcast round-trip through <8 x i8> followed by a zext) and
// then shifted left by a constant vector of 9s. The cir.shift check consumes
// the two captured operands, matching the sibling vshll_n tests.
uint32x4_t test_vshll_n_u16(uint16x4_t a) {
  return vshll_n_u16(a, 9);

  // CIR-LABEL: vshll_n_u16
  // CIR: [[SHIFT_TGT:%.*]] = cir.cast(integral, {{%.*}} : !cir.vector<!u16i x 4>), !cir.vector<!u32i x 4>
  // CIR: [[SHIFT_AMT:%.*]] = cir.const #cir.const_vector<[#cir.int<9> : !u32i, #cir.int<9> : !u32i,
  // CIR-SAME: #cir.int<9> : !u32i, #cir.int<9> : !u32i]> : !cir.vector<!u32i x 4>
  // CIR: {{%.*}} = cir.shift(left, [[SHIFT_TGT]] : !cir.vector<!u32i x 4>, [[SHIFT_AMT]] : !cir.vector<!u32i x 4>)

  // LLVM: {{.*}}@test_vshll_n_u16(<4 x i16>{{.*}}[[A:%.*]])
  // LLVM: [[TMP0:%.*]] = bitcast <4 x i16> [[A]] to <8 x i8>
  // LLVM: [[TMP1:%.*]] = bitcast <8 x i8> [[TMP0]] to <4 x i16>
  // LLVM: [[TMP2:%.*]] = zext <4 x i16> [[TMP1]] to <4 x i32>
  // LLVM: [[VSHLL_N:%.*]] = shl <4 x i32> [[TMP2]], <i32 9, i32 9, i32 9, i32 9>
  // LLVM: ret <4 x i32> [[VSHLL_N]]
}
6667+
6668+
// Exercises vshll_n_u32 with a shift of 19. The checks below expect the
// unsigned 32-bit lanes to be widened to 64 bits (an integral cast in CIR; in
// the lowered IR a bitcast round-trip through <8 x i8> followed by a zext) and
// then shifted left by a constant vector of 19s.
uint64x2_t test_vshll_n_u32(uint32x2_t a) {
6669+
return vshll_n_u32(a, 19);
6670+
6671+
// CIR-LABEL: vshll_n_u32
6672+
// CIR: [[SHIFT_TGT:%.*]] = cir.cast(integral, {{%.*}} : !cir.vector<!u32i x 2>), !cir.vector<!u64i x 2>
6673+
// CIR: [[SHIFT_AMT:%.*]] = cir.const #cir.const_vector<[#cir.int<19> : !u64i, #cir.int<19> : !u64i]> : !cir.vector<!u64i x 2>
6674+
// CIR: {{%.*}} = cir.shift(left, [[SHIFT_TGT]] : !cir.vector<!u64i x 2>, [[SHIFT_AMT]] : !cir.vector<!u64i x 2>)
6675+
6676+
// LLVM: {{.*}}@test_vshll_n_u32(<2 x i32>{{.*}}[[A:%.*]])
6677+
// LLVM: [[TMP0:%.*]] = bitcast <2 x i32> [[A]] to <8 x i8>
6678+
// LLVM: [[TMP1:%.*]] = bitcast <8 x i8> [[TMP0]] to <2 x i32>
6679+
// LLVM: [[TMP2:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i64>
6680+
// LLVM: [[VSHLL_N:%.*]] = shl <2 x i64> [[TMP2]], <i64 19, i64 19>
6681+
// LLVM: ret <2 x i64> [[VSHLL_N]]
6682+
}
66456683

66466684
// NYI-LABEL: @test_vshll_high_n_s8(
66476685
// NYI: [[SHUFFLE_I:%.*]] = shufflevector <16 x i8> %a, <16 x i8> %a, <8 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>

0 commit comments

Comments
 (0)