Skip to content

Commit cdc55eb

Browse files
Lancernlanza
authored andcommitted
[CIR] Add support for unary complex operations (#750)
This PR adds support for unary operations on complex numbers, namely plus(+), minus(-), and conjugate(~). This PR also adds support for the `__builtin_conj` builtin function which computes the conjugate of the input.
1 parent 18c1f53 commit cdc55eb

File tree

4 files changed

+243
-17
lines changed

4 files changed

+243
-17
lines changed

clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -756,8 +756,12 @@ RValue CIRGenFunction::buildBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
756756
case Builtin::BI__builtin_conjl:
757757
case Builtin::BIconj:
758758
case Builtin::BIconjf:
759-
case Builtin::BIconjl:
760-
llvm_unreachable("NYI");
759+
case Builtin::BIconjl: {
760+
mlir::Value ComplexVal = buildComplexExpr(E->getArg(0));
761+
mlir::Value Conj = builder.createUnaryOp(
762+
getLoc(E->getExprLoc()), mlir::cir::UnaryOpKind::Not, ComplexVal);
763+
return RValue::getComplex(Conj);
764+
}
761765

762766
case Builtin::BI__builtin___CFStringMakeConstantString:
763767
case Builtin::BI__builtin___NSStringMakeConstantString:

clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp

Lines changed: 57 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -152,20 +152,12 @@ class ComplexExprEmitter : public StmtVisitor<ComplexExprEmitter, mlir::Value> {
152152
mlir::Value VisitUnaryDeref(const Expr *E) { llvm_unreachable("NYI"); }
153153

154154
mlir::Value VisitUnaryPlus(const UnaryOperator *E,
155-
QualType PromotionType = QualType()) {
156-
llvm_unreachable("NYI");
157-
}
158-
mlir::Value VisitPlus(const UnaryOperator *E, QualType PromotionType) {
159-
llvm_unreachable("NYI");
160-
}
155+
QualType PromotionType = QualType());
156+
mlir::Value VisitPlus(const UnaryOperator *E, QualType PromotionType);
161157
mlir::Value VisitUnaryMinus(const UnaryOperator *E,
162-
QualType PromotionType = QualType()) {
163-
llvm_unreachable("NYI");
164-
}
165-
mlir::Value VisitMinus(const UnaryOperator *E, QualType PromotionType) {
166-
llvm_unreachable("NYI");
167-
}
168-
mlir::Value VisitUnaryNot(const UnaryOperator *E) { llvm_unreachable("NYI"); }
158+
QualType PromotionType = QualType());
159+
mlir::Value VisitMinus(const UnaryOperator *E, QualType PromotionType);
160+
mlir::Value VisitUnaryNot(const UnaryOperator *E);
169161
// LNot,Real,Imag never return complex.
170162
mlir::Value VisitUnaryExtension(const UnaryOperator *E) {
171163
return Visit(E->getSubExpr());
@@ -495,6 +487,58 @@ mlir::Value ComplexExprEmitter::VisitCallExpr(const CallExpr *E) {
495487
return CGF.buildCallExpr(E).getComplexVal();
496488
}
497489

490+
mlir::Value ComplexExprEmitter::VisitUnaryPlus(const UnaryOperator *E,
491+
QualType PromotionType) {
492+
QualType promotionTy = PromotionType.isNull()
493+
? getPromotionType(E->getSubExpr()->getType())
494+
: PromotionType;
495+
mlir::Value result = VisitPlus(E, promotionTy);
496+
if (!promotionTy.isNull())
497+
return CGF.buildUnPromotedValue(result, E->getSubExpr()->getType());
498+
return result;
499+
}
500+
501+
mlir::Value ComplexExprEmitter::VisitPlus(const UnaryOperator *E,
502+
QualType PromotionType) {
503+
mlir::Value Op;
504+
if (!PromotionType.isNull())
505+
Op = CGF.buildPromotedComplexExpr(E->getSubExpr(), PromotionType);
506+
else
507+
Op = Visit(E->getSubExpr());
508+
509+
return Builder.createUnaryOp(CGF.getLoc(E->getExprLoc()),
510+
mlir::cir::UnaryOpKind::Plus, Op);
511+
}
512+
513+
mlir::Value ComplexExprEmitter::VisitUnaryMinus(const UnaryOperator *E,
514+
QualType PromotionType) {
515+
QualType promotionTy = PromotionType.isNull()
516+
? getPromotionType(E->getSubExpr()->getType())
517+
: PromotionType;
518+
mlir::Value result = VisitMinus(E, promotionTy);
519+
if (!promotionTy.isNull())
520+
return CGF.buildUnPromotedValue(result, E->getSubExpr()->getType());
521+
return result;
522+
}
523+
524+
mlir::Value ComplexExprEmitter::VisitMinus(const UnaryOperator *E,
525+
QualType PromotionType) {
526+
mlir::Value Op;
527+
if (!PromotionType.isNull())
528+
Op = CGF.buildPromotedComplexExpr(E->getSubExpr(), PromotionType);
529+
else
530+
Op = Visit(E->getSubExpr());
531+
532+
return Builder.createUnaryOp(CGF.getLoc(E->getExprLoc()),
533+
mlir::cir::UnaryOpKind::Minus, Op);
534+
}
535+
536+
mlir::Value ComplexExprEmitter::VisitUnaryNot(const UnaryOperator *E) {
537+
mlir::Value Op = Visit(E->getSubExpr());
538+
return Builder.createUnaryOp(CGF.getLoc(E->getExprLoc()),
539+
mlir::cir::UnaryOpKind::Not, Op);
540+
}
541+
498542
ComplexExprEmitter::BinOpInfo
499543
ComplexExprEmitter::buildBinOps(const BinaryOperator *E, QualType PromotionTy) {
500544
BinOpInfo Ops{CGF.getLoc(E->getExprLoc())};

clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
7171
void runOnOperation() override;
7272

7373
void runOnOp(Operation *op);
74+
void lowerUnaryOp(UnaryOp op);
7475
void lowerBinOp(BinOp op);
7576
void lowerComplexBinOp(ComplexBinOp op);
7677
void lowerThreeWayCmpOp(CmpThreeWayOp op);
@@ -347,6 +348,50 @@ void LoweringPreparePass::lowerVAArgOp(VAArgOp op) {
347348
return;
348349
}
349350

351+
void LoweringPreparePass::lowerUnaryOp(UnaryOp op) {
352+
auto ty = op.getType();
353+
if (!mlir::isa<mlir::cir::ComplexType>(ty))
354+
return;
355+
356+
auto loc = op.getLoc();
357+
auto opKind = op.getKind();
358+
assert((opKind == mlir::cir::UnaryOpKind::Plus ||
359+
opKind == mlir::cir::UnaryOpKind::Minus ||
360+
opKind == mlir::cir::UnaryOpKind::Not) &&
361+
"invalid unary op kind on complex numbers");
362+
363+
CIRBaseBuilderTy builder(getContext());
364+
builder.setInsertionPointAfter(op);
365+
366+
auto operand = op.getInput();
367+
368+
auto operandReal = builder.createComplexReal(loc, operand);
369+
auto operandImag = builder.createComplexImag(loc, operand);
370+
371+
mlir::Value resultReal;
372+
mlir::Value resultImag;
373+
switch (opKind) {
374+
case mlir::cir::UnaryOpKind::Plus:
375+
case mlir::cir::UnaryOpKind::Minus:
376+
resultReal = builder.createUnaryOp(loc, opKind, operandReal);
377+
resultImag = builder.createUnaryOp(loc, opKind, operandImag);
378+
break;
379+
380+
case mlir::cir::UnaryOpKind::Not:
381+
resultReal = operandReal;
382+
resultImag =
383+
builder.createUnaryOp(loc, mlir::cir::UnaryOpKind::Minus, operandImag);
384+
break;
385+
386+
default:
387+
llvm_unreachable("unsupported complex unary op kind");
388+
}
389+
390+
auto result = builder.createComplexCreate(loc, resultReal, resultImag);
391+
op.replaceAllUsesWith(result);
392+
op.erase();
393+
}
394+
350395
void LoweringPreparePass::lowerBinOp(BinOp op) {
351396
auto ty = op.getType();
352397
if (!mlir::isa<mlir::cir::ComplexType>(ty))
@@ -939,7 +984,9 @@ void LoweringPreparePass::lowerIterEndOp(IterEndOp op) {
939984
}
940985

941986
void LoweringPreparePass::runOnOp(Operation *op) {
942-
if (auto bin = dyn_cast<BinOp>(op)) {
987+
if (auto unary = dyn_cast<UnaryOp>(op)) {
988+
lowerUnaryOp(unary);
989+
} else if (auto bin = dyn_cast<BinOp>(op)) {
943990
lowerBinOp(bin);
944991
} else if (auto complexBin = dyn_cast<ComplexBinOp>(op)) {
945992
lowerComplexBinOp(complexBin);
@@ -980,7 +1027,7 @@ void LoweringPreparePass::runOnOperation() {
9801027
SmallVector<Operation *> opsToTransform;
9811028

9821029
op->walk([&](Operation *op) {
983-
if (isa<BinOp, ComplexBinOp, CmpThreeWayOp, VAArgOp, GlobalOp,
1030+
if (isa<UnaryOp, BinOp, ComplexBinOp, CmpThreeWayOp, VAArgOp, GlobalOp,
9841031
DynamicCastOp, StdFindOp, IterEndOp, IterBeginOp, ArrayCtor,
9851032
ArrayDtor, mlir::cir::FuncOp>(op))
9861033
opsToTransform.push_back(op);

clang/test/CIR/CodeGen/complex-arithmetic.c

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -645,3 +645,134 @@ void div_assign() {
645645
// CIRGEN-FULL: %{{.+}} = cir.complex.binop div %{{.+}}, %{{.+}} range(full) : !cir.complex<!s32i>
646646

647647
// CHECK: }
648+
649+
void unary_plus() {
650+
cd1 = +cd1;
651+
ci1 = +ci1;
652+
}
653+
654+
// CLANG: @unary_plus
655+
// CPPLANG: @_Z10unary_plusv
656+
657+
// CIRGEN: %{{.+}} = cir.unary(plus, %{{.+}}) : !cir.complex<!cir.double>, !cir.complex<!cir.double>
658+
// CIRGEN: %{{.+}} = cir.unary(plus, %{{.+}}) : !cir.complex<!s32i>, !cir.complex<!s32i>
659+
660+
// CIR: %[[#OPR:]] = cir.complex.real %{{.+}} : !cir.complex<!cir.double> -> !cir.double
661+
// CIR-NEXT: %[[#OPI:]] = cir.complex.imag %{{.+}} : !cir.complex<!cir.double> -> !cir.double
662+
// CIR-NEXT: %[[#RESR:]] = cir.unary(plus, %[[#OPR]]) : !cir.double, !cir.double
663+
// CIR-NEXT: %[[#RESI:]] = cir.unary(plus, %[[#OPI]]) : !cir.double, !cir.double
664+
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#RESR]], %[[#RESI]] : !cir.double -> !cir.complex<!cir.double>
665+
666+
// CIR: %[[#OPR:]] = cir.complex.real %{{.+}} : !cir.complex<!s32i> -> !s32i
667+
// CIR-NEXT: %[[#OPI:]] = cir.complex.imag %{{.+}} : !cir.complex<!s32i> -> !s32i
668+
// CIR-NEXT: %[[#RESR:]] = cir.unary(plus, %[[#OPR]]) : !s32i, !s32i
669+
// CIR-NEXT: %[[#RESI:]] = cir.unary(plus, %[[#OPI]]) : !s32i, !s32i
670+
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#RESR]], %[[#RESI]] : !s32i -> !cir.complex<!s32i>
671+
672+
// LLVM: %[[#OPR:]] = extractvalue { double, double } %{{.+}}, 0
673+
// LLVM-NEXT: %[[#OPI:]] = extractvalue { double, double } %{{.+}}, 1
674+
// LLVM-NEXT: %[[#A:]] = insertvalue { double, double } undef, double %[[#OPR]], 0
675+
// LLVM-NEXT: %{{.+}} = insertvalue { double, double } %[[#A]], double %[[#OPI]], 1
676+
677+
// LLVM: %[[#OPR:]] = extractvalue { i32, i32 } %{{.+}}, 0
678+
// LLVM-NEXT: %[[#OPI:]] = extractvalue { i32, i32 } %{{.+}}, 1
679+
// LLVM-NEXT: %[[#A:]] = insertvalue { i32, i32 } undef, i32 %[[#OPR]], 0
680+
// LLVM-NEXT: %{{.+}} = insertvalue { i32, i32 } %[[#A]], i32 %[[#OPI]], 1
681+
682+
// CHECK: }
683+
684+
void unary_minus() {
685+
cd1 = -cd1;
686+
ci1 = -ci1;
687+
}
688+
689+
// CLANG: @unary_minus
690+
// CPPLANG: @_Z11unary_minusv
691+
692+
// CIRGEN: %{{.+}} = cir.unary(minus, %{{.+}}) : !cir.complex<!cir.double>, !cir.complex<!cir.double>
693+
// CIRGEN: %{{.+}} = cir.unary(minus, %{{.+}}) : !cir.complex<!s32i>, !cir.complex<!s32i>
694+
695+
// CIR: %[[#OPR:]] = cir.complex.real %{{.+}} : !cir.complex<!cir.double> -> !cir.double
696+
// CIR-NEXT: %[[#OPI:]] = cir.complex.imag %{{.+}} : !cir.complex<!cir.double> -> !cir.double
697+
// CIR-NEXT: %[[#RESR:]] = cir.unary(minus, %[[#OPR]]) : !cir.double, !cir.double
698+
// CIR-NEXT: %[[#RESI:]] = cir.unary(minus, %[[#OPI]]) : !cir.double, !cir.double
699+
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#RESR]], %[[#RESI]] : !cir.double -> !cir.complex<!cir.double>
700+
701+
// CIR: %[[#OPR:]] = cir.complex.real %{{.+}} : !cir.complex<!s32i> -> !s32i
702+
// CIR-NEXT: %[[#OPI:]] = cir.complex.imag %{{.+}} : !cir.complex<!s32i> -> !s32i
703+
// CIR-NEXT: %[[#RESR:]] = cir.unary(minus, %[[#OPR]]) : !s32i, !s32i
704+
// CIR-NEXT: %[[#RESI:]] = cir.unary(minus, %[[#OPI]]) : !s32i, !s32i
705+
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#RESR]], %[[#RESI]] : !s32i -> !cir.complex<!s32i>
706+
707+
// LLVM: %[[#OPR:]] = extractvalue { double, double } %{{.+}}, 0
708+
// LLVM-NEXT: %[[#OPI:]] = extractvalue { double, double } %{{.+}}, 1
709+
// LLVM-NEXT: %[[#RESR:]] = fneg double %[[#OPR]]
710+
// LLVM-NEXT: %[[#RESI:]] = fneg double %[[#OPI]]
711+
// LLVM-NEXT: %[[#A:]] = insertvalue { double, double } undef, double %[[#RESR]], 0
712+
// LLVM-NEXT: %{{.+}} = insertvalue { double, double } %[[#A]], double %[[#RESI]], 1
713+
714+
// LLVM: %[[#OPR:]] = extractvalue { i32, i32 } %{{.+}}, 0
715+
// LLVM-NEXT: %[[#OPI:]] = extractvalue { i32, i32 } %{{.+}}, 1
716+
// LLVM-NEXT: %[[#RESR:]] = sub i32 0, %[[#OPR]]
717+
// LLVM-NEXT: %[[#RESI:]] = sub i32 0, %[[#OPI]]
718+
// LLVM-NEXT: %[[#A:]] = insertvalue { i32, i32 } undef, i32 %[[#RESR]], 0
719+
// LLVM-NEXT: %{{.+}} = insertvalue { i32, i32 } %[[#A]], i32 %[[#RESI]], 1
720+
721+
// CHECK: }
722+
723+
void unary_not() {
724+
cd1 = ~cd1;
725+
ci1 = ~ci1;
726+
}
727+
728+
// CLANG: @unary_not
729+
// CPPLANG: @_Z9unary_notv
730+
731+
// CIRGEN: %{{.+}} = cir.unary(not, %{{.+}}) : !cir.complex<!cir.double>, !cir.complex<!cir.double>
732+
// CIRGEN: %{{.+}} = cir.unary(not, %{{.+}}) : !cir.complex<!s32i>, !cir.complex<!s32i>
733+
734+
// CIR: %[[#OPR:]] = cir.complex.real %{{.+}} : !cir.complex<!cir.double> -> !cir.double
735+
// CIR-NEXT: %[[#OPI:]] = cir.complex.imag %{{.+}} : !cir.complex<!cir.double> -> !cir.double
736+
// CIR-NEXT: %[[#RESI:]] = cir.unary(minus, %[[#OPI]]) : !cir.double, !cir.double
737+
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#OPR]], %[[#RESI]] : !cir.double -> !cir.complex<!cir.double>
738+
739+
// CIR: %[[#OPR:]] = cir.complex.real %{{.+}} : !cir.complex<!s32i> -> !s32i
740+
// CIR-NEXT: %[[#OPI:]] = cir.complex.imag %{{.+}} : !cir.complex<!s32i> -> !s32i
741+
// CIR-NEXT: %[[#RESI:]] = cir.unary(minus, %[[#OPI]]) : !s32i, !s32i
742+
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#OPR]], %[[#RESI]] : !s32i -> !cir.complex<!s32i>
743+
744+
// LLVM: %[[#OPR:]] = extractvalue { double, double } %{{.+}}, 0
745+
// LLVM-NEXT: %[[#OPI:]] = extractvalue { double, double } %{{.+}}, 1
746+
// LLVM-NEXT: %[[#RESI:]] = fneg double %[[#OPI]]
747+
// LLVM-NEXT: %[[#A:]] = insertvalue { double, double } undef, double %[[#OPR]], 0
748+
// LLVM-NEXT: %{{.+}} = insertvalue { double, double } %[[#A]], double %[[#RESI]], 1
749+
750+
// LLVM: %[[#OPR:]] = extractvalue { i32, i32 } %{{.+}}, 0
751+
// LLVM-NEXT: %[[#OPI:]] = extractvalue { i32, i32 } %{{.+}}, 1
752+
// LLVM-NEXT: %[[#RESI:]] = sub i32 0, %[[#OPI]]
753+
// LLVM-NEXT: %[[#A:]] = insertvalue { i32, i32 } undef, i32 %[[#OPR]], 0
754+
// LLVM-NEXT: %{{.+}} = insertvalue { i32, i32 } %[[#A]], i32 %[[#RESI]], 1
755+
756+
// CHECK: }
757+
758+
void builtin_conj() {
759+
cd1 = __builtin_conj(cd1);
760+
}
761+
762+
// CLANG: @builtin_conj
763+
// CPPLANG: @_Z12builtin_conjv
764+
765+
// CIRGEN: %{{.+}} = cir.unary(not, %{{.+}}) : !cir.complex<!cir.double>, !cir.complex<!cir.double>
766+
767+
// CIR: %[[#OPR:]] = cir.complex.real %{{.+}} : !cir.complex<!cir.double> -> !cir.double
768+
// CIR-NEXT: %[[#OPI:]] = cir.complex.imag %{{.+}} : !cir.complex<!cir.double> -> !cir.double
769+
// CIR-NEXT: %[[#RESI:]] = cir.unary(minus, %[[#OPI]]) : !cir.double, !cir.double
770+
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#OPR]], %[[#RESI]] : !cir.double -> !cir.complex<!cir.double>
771+
772+
// LLVM: %[[#OPR:]] = extractvalue { double, double } %{{.+}}, 0
773+
// LLVM-NEXT: %[[#OPI:]] = extractvalue { double, double } %{{.+}}, 1
774+
// LLVM-NEXT: %[[#RESI:]] = fneg double %[[#OPI]]
775+
// LLVM-NEXT: %[[#A:]] = insertvalue { double, double } undef, double %[[#OPR]], 0
776+
// LLVM-NEXT: %{{.+}} = insertvalue { double, double } %[[#A]], double %[[#RESI]], 1
777+
778+
// CHECK: }

0 commit comments

Comments
 (0)