Skip to content

[CIR] Add support for unary complex operations #750

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -756,8 +756,12 @@ RValue CIRGenFunction::buildBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BI__builtin_conjl:
case Builtin::BIconj:
case Builtin::BIconjf:
case Builtin::BIconjl:
llvm_unreachable("NYI");
case Builtin::BIconjl: {
mlir::Value ComplexVal = buildComplexExpr(E->getArg(0));
mlir::Value Conj = builder.createUnaryOp(
getLoc(E->getExprLoc()), mlir::cir::UnaryOpKind::Not, ComplexVal);
return RValue::getComplex(Conj);
}

case Builtin::BI__builtin___CFStringMakeConstantString:
case Builtin::BI__builtin___NSStringMakeConstantString:
Expand Down
70 changes: 57 additions & 13 deletions clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,20 +152,12 @@ class ComplexExprEmitter : public StmtVisitor<ComplexExprEmitter, mlir::Value> {
mlir::Value VisitUnaryDeref(const Expr *E) { llvm_unreachable("NYI"); }

mlir::Value VisitUnaryPlus(const UnaryOperator *E,
QualType PromotionType = QualType()) {
llvm_unreachable("NYI");
}
mlir::Value VisitPlus(const UnaryOperator *E, QualType PromotionType) {
llvm_unreachable("NYI");
}
QualType PromotionType = QualType());
mlir::Value VisitPlus(const UnaryOperator *E, QualType PromotionType);
mlir::Value VisitUnaryMinus(const UnaryOperator *E,
QualType PromotionType = QualType()) {
llvm_unreachable("NYI");
}
mlir::Value VisitMinus(const UnaryOperator *E, QualType PromotionType) {
llvm_unreachable("NYI");
}
mlir::Value VisitUnaryNot(const UnaryOperator *E) { llvm_unreachable("NYI"); }
QualType PromotionType = QualType());
mlir::Value VisitMinus(const UnaryOperator *E, QualType PromotionType);
mlir::Value VisitUnaryNot(const UnaryOperator *E);
// LNot,Real,Imag never return complex.
mlir::Value VisitUnaryExtension(const UnaryOperator *E) {
return Visit(E->getSubExpr());
Expand Down Expand Up @@ -508,6 +500,58 @@ mlir::Value ComplexExprEmitter::VisitCallExpr(const CallExpr *E) {
return CGF.buildCallExpr(E).getComplexVal();
}

mlir::Value ComplexExprEmitter::VisitUnaryPlus(const UnaryOperator *E,
QualType PromotionType) {
QualType promotionTy = PromotionType.isNull()
? getPromotionType(E->getSubExpr()->getType())
: PromotionType;
mlir::Value result = VisitPlus(E, promotionTy);
if (!promotionTy.isNull())
return CGF.buildUnPromotedValue(result, E->getSubExpr()->getType());
return result;
}

mlir::Value ComplexExprEmitter::VisitPlus(const UnaryOperator *E,
QualType PromotionType) {
mlir::Value Op;
if (!PromotionType.isNull())
Op = CGF.buildPromotedComplexExpr(E->getSubExpr(), PromotionType);
else
Op = Visit(E->getSubExpr());

return Builder.createUnaryOp(CGF.getLoc(E->getExprLoc()),
mlir::cir::UnaryOpKind::Plus, Op);
}

mlir::Value ComplexExprEmitter::VisitUnaryMinus(const UnaryOperator *E,
QualType PromotionType) {
QualType promotionTy = PromotionType.isNull()
? getPromotionType(E->getSubExpr()->getType())
: PromotionType;
mlir::Value result = VisitMinus(E, promotionTy);
if (!promotionTy.isNull())
return CGF.buildUnPromotedValue(result, E->getSubExpr()->getType());
return result;
}

mlir::Value ComplexExprEmitter::VisitMinus(const UnaryOperator *E,
QualType PromotionType) {
mlir::Value Op;
if (!PromotionType.isNull())
Op = CGF.buildPromotedComplexExpr(E->getSubExpr(), PromotionType);
else
Op = Visit(E->getSubExpr());

return Builder.createUnaryOp(CGF.getLoc(E->getExprLoc()),
mlir::cir::UnaryOpKind::Minus, Op);
}

mlir::Value ComplexExprEmitter::VisitUnaryNot(const UnaryOperator *E) {
mlir::Value Op = Visit(E->getSubExpr());
return Builder.createUnaryOp(CGF.getLoc(E->getExprLoc()),
mlir::cir::UnaryOpKind::Not, Op);
}

ComplexExprEmitter::BinOpInfo
ComplexExprEmitter::buildBinOps(const BinaryOperator *E, QualType PromotionTy) {
BinOpInfo Ops{CGF.getLoc(E->getExprLoc())};
Expand Down
51 changes: 49 additions & 2 deletions clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
void runOnOperation() override;

void runOnOp(Operation *op);
void lowerUnaryOp(UnaryOp op);
void lowerBinOp(BinOp op);
void lowerComplexBinOp(ComplexBinOp op);
void lowerThreeWayCmpOp(CmpThreeWayOp op);
Expand Down Expand Up @@ -347,6 +348,50 @@ void LoweringPreparePass::lowerVAArgOp(VAArgOp op) {
return;
}

void LoweringPreparePass::lowerUnaryOp(UnaryOp op) {
auto ty = op.getType();
if (!mlir::isa<mlir::cir::ComplexType>(ty))
return;

auto loc = op.getLoc();
auto opKind = op.getKind();
assert((opKind == mlir::cir::UnaryOpKind::Plus ||
opKind == mlir::cir::UnaryOpKind::Minus ||
opKind == mlir::cir::UnaryOpKind::Not) &&
"invalid unary op kind on complex numbers");

CIRBaseBuilderTy builder(getContext());
builder.setInsertionPointAfter(op);

auto operand = op.getInput();

auto operandReal = builder.createComplexReal(loc, operand);
auto operandImag = builder.createComplexImag(loc, operand);

mlir::Value resultReal;
mlir::Value resultImag;
switch (opKind) {
case mlir::cir::UnaryOpKind::Plus:
case mlir::cir::UnaryOpKind::Minus:
resultReal = builder.createUnaryOp(loc, opKind, operandReal);
resultImag = builder.createUnaryOp(loc, opKind, operandImag);
break;

case mlir::cir::UnaryOpKind::Not:
resultReal = operandReal;
resultImag =
builder.createUnaryOp(loc, mlir::cir::UnaryOpKind::Minus, operandImag);
break;

default:
llvm_unreachable("unsupported complex unary op kind");
}

auto result = builder.createComplexCreate(loc, resultReal, resultImag);
op.replaceAllUsesWith(result);
op.erase();
}

void LoweringPreparePass::lowerBinOp(BinOp op) {
auto ty = op.getType();
if (!mlir::isa<mlir::cir::ComplexType>(ty))
Expand Down Expand Up @@ -939,7 +984,9 @@ void LoweringPreparePass::lowerIterEndOp(IterEndOp op) {
}

void LoweringPreparePass::runOnOp(Operation *op) {
if (auto bin = dyn_cast<BinOp>(op)) {
if (auto unary = dyn_cast<UnaryOp>(op)) {
lowerUnaryOp(unary);
} else if (auto bin = dyn_cast<BinOp>(op)) {
lowerBinOp(bin);
} else if (auto complexBin = dyn_cast<ComplexBinOp>(op)) {
lowerComplexBinOp(complexBin);
Expand Down Expand Up @@ -980,7 +1027,7 @@ void LoweringPreparePass::runOnOperation() {
SmallVector<Operation *> opsToTransform;

op->walk([&](Operation *op) {
if (isa<BinOp, ComplexBinOp, CmpThreeWayOp, VAArgOp, GlobalOp,
if (isa<UnaryOp, BinOp, ComplexBinOp, CmpThreeWayOp, VAArgOp, GlobalOp,
DynamicCastOp, StdFindOp, IterEndOp, IterBeginOp, ArrayCtor,
ArrayDtor, mlir::cir::FuncOp>(op))
opsToTransform.push_back(op);
Expand Down
131 changes: 131 additions & 0 deletions clang/test/CIR/CodeGen/complex-arithmetic.c
Original file line number Diff line number Diff line change
Expand Up @@ -645,3 +645,134 @@ void div_assign() {
// CIRGEN-FULL: %{{.+}} = cir.complex.binop div %{{.+}}, %{{.+}} range(full) : !cir.complex<!s32i>

// CHECK: }

void unary_plus() {
cd1 = +cd1;
ci1 = +ci1;
}

// CLANG: @unary_plus
// CPPLANG: @_Z10unary_plusv

// CIRGEN: %{{.+}} = cir.unary(plus, %{{.+}}) : !cir.complex<!cir.double>, !cir.complex<!cir.double>
// CIRGEN: %{{.+}} = cir.unary(plus, %{{.+}}) : !cir.complex<!s32i>, !cir.complex<!s32i>

// CIR: %[[#OPR:]] = cir.complex.real %{{.+}} : !cir.complex<!cir.double> -> !cir.double
// CIR-NEXT: %[[#OPI:]] = cir.complex.imag %{{.+}} : !cir.complex<!cir.double> -> !cir.double
// CIR-NEXT: %[[#RESR:]] = cir.unary(plus, %[[#OPR]]) : !cir.double, !cir.double
// CIR-NEXT: %[[#RESI:]] = cir.unary(plus, %[[#OPI]]) : !cir.double, !cir.double
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#RESR]], %[[#RESI]] : !cir.double -> !cir.complex<!cir.double>

// CIR: %[[#OPR:]] = cir.complex.real %{{.+}} : !cir.complex<!s32i> -> !s32i
// CIR-NEXT: %[[#OPI:]] = cir.complex.imag %{{.+}} : !cir.complex<!s32i> -> !s32i
// CIR-NEXT: %[[#RESR:]] = cir.unary(plus, %[[#OPR]]) : !s32i, !s32i
// CIR-NEXT: %[[#RESI:]] = cir.unary(plus, %[[#OPI]]) : !s32i, !s32i
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#RESR]], %[[#RESI]] : !s32i -> !cir.complex<!s32i>

// LLVM: %[[#OPR:]] = extractvalue { double, double } %{{.+}}, 0
// LLVM-NEXT: %[[#OPI:]] = extractvalue { double, double } %{{.+}}, 1
// LLVM-NEXT: %[[#A:]] = insertvalue { double, double } undef, double %[[#OPR]], 0
// LLVM-NEXT: %{{.+}} = insertvalue { double, double } %[[#A]], double %[[#OPI]], 1

// LLVM: %[[#OPR:]] = extractvalue { i32, i32 } %{{.+}}, 0
// LLVM-NEXT: %[[#OPI:]] = extractvalue { i32, i32 } %{{.+}}, 1
// LLVM-NEXT: %[[#A:]] = insertvalue { i32, i32 } undef, i32 %[[#OPR]], 0
// LLVM-NEXT: %{{.+}} = insertvalue { i32, i32 } %[[#A]], i32 %[[#OPI]], 1

// CHECK: }

void unary_minus() {
cd1 = -cd1;
ci1 = -ci1;
}

// CLANG: @unary_minus
// CPPLANG: @_Z11unary_minusv

// CIRGEN: %{{.+}} = cir.unary(minus, %{{.+}}) : !cir.complex<!cir.double>, !cir.complex<!cir.double>
// CIRGEN: %{{.+}} = cir.unary(minus, %{{.+}}) : !cir.complex<!s32i>, !cir.complex<!s32i>

// CIR: %[[#OPR:]] = cir.complex.real %{{.+}} : !cir.complex<!cir.double> -> !cir.double
// CIR-NEXT: %[[#OPI:]] = cir.complex.imag %{{.+}} : !cir.complex<!cir.double> -> !cir.double
// CIR-NEXT: %[[#RESR:]] = cir.unary(minus, %[[#OPR]]) : !cir.double, !cir.double
// CIR-NEXT: %[[#RESI:]] = cir.unary(minus, %[[#OPI]]) : !cir.double, !cir.double
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#RESR]], %[[#RESI]] : !cir.double -> !cir.complex<!cir.double>

// CIR: %[[#OPR:]] = cir.complex.real %{{.+}} : !cir.complex<!s32i> -> !s32i
// CIR-NEXT: %[[#OPI:]] = cir.complex.imag %{{.+}} : !cir.complex<!s32i> -> !s32i
// CIR-NEXT: %[[#RESR:]] = cir.unary(minus, %[[#OPR]]) : !s32i, !s32i
// CIR-NEXT: %[[#RESI:]] = cir.unary(minus, %[[#OPI]]) : !s32i, !s32i
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#RESR]], %[[#RESI]] : !s32i -> !cir.complex<!s32i>

// LLVM: %[[#OPR:]] = extractvalue { double, double } %{{.+}}, 0
// LLVM-NEXT: %[[#OPI:]] = extractvalue { double, double } %{{.+}}, 1
// LLVM-NEXT: %[[#RESR:]] = fneg double %[[#OPR]]
// LLVM-NEXT: %[[#RESI:]] = fneg double %[[#OPI]]
// LLVM-NEXT: %[[#A:]] = insertvalue { double, double } undef, double %[[#RESR]], 0
// LLVM-NEXT: %{{.+}} = insertvalue { double, double } %[[#A]], double %[[#RESI]], 1

// LLVM: %[[#OPR:]] = extractvalue { i32, i32 } %{{.+}}, 0
// LLVM-NEXT: %[[#OPI:]] = extractvalue { i32, i32 } %{{.+}}, 1
// LLVM-NEXT: %[[#RESR:]] = sub i32 0, %[[#OPR]]
// LLVM-NEXT: %[[#RESI:]] = sub i32 0, %[[#OPI]]
// LLVM-NEXT: %[[#A:]] = insertvalue { i32, i32 } undef, i32 %[[#RESR]], 0
// LLVM-NEXT: %{{.+}} = insertvalue { i32, i32 } %[[#A]], i32 %[[#RESI]], 1

// CHECK: }

void unary_not() {
cd1 = ~cd1;
ci1 = ~ci1;
}

// CLANG: @unary_not
// CPPLANG: @_Z9unary_notv

// CIRGEN: %{{.+}} = cir.unary(not, %{{.+}}) : !cir.complex<!cir.double>, !cir.complex<!cir.double>
// CIRGEN: %{{.+}} = cir.unary(not, %{{.+}}) : !cir.complex<!s32i>, !cir.complex<!s32i>

// CIR: %[[#OPR:]] = cir.complex.real %{{.+}} : !cir.complex<!cir.double> -> !cir.double
// CIR-NEXT: %[[#OPI:]] = cir.complex.imag %{{.+}} : !cir.complex<!cir.double> -> !cir.double
// CIR-NEXT: %[[#RESI:]] = cir.unary(minus, %[[#OPI]]) : !cir.double, !cir.double
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#OPR]], %[[#RESI]] : !cir.double -> !cir.complex<!cir.double>

// CIR: %[[#OPR:]] = cir.complex.real %{{.+}} : !cir.complex<!s32i> -> !s32i
// CIR-NEXT: %[[#OPI:]] = cir.complex.imag %{{.+}} : !cir.complex<!s32i> -> !s32i
// CIR-NEXT: %[[#RESI:]] = cir.unary(minus, %[[#OPI]]) : !s32i, !s32i
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#OPR]], %[[#RESI]] : !s32i -> !cir.complex<!s32i>

// LLVM: %[[#OPR:]] = extractvalue { double, double } %{{.+}}, 0
// LLVM-NEXT: %[[#OPI:]] = extractvalue { double, double } %{{.+}}, 1
// LLVM-NEXT: %[[#RESI:]] = fneg double %[[#OPI]]
// LLVM-NEXT: %[[#A:]] = insertvalue { double, double } undef, double %[[#OPR]], 0
// LLVM-NEXT: %{{.+}} = insertvalue { double, double } %[[#A]], double %[[#RESI]], 1

// LLVM: %[[#OPR:]] = extractvalue { i32, i32 } %{{.+}}, 0
// LLVM-NEXT: %[[#OPI:]] = extractvalue { i32, i32 } %{{.+}}, 1
// LLVM-NEXT: %[[#RESI:]] = sub i32 0, %[[#OPI]]
// LLVM-NEXT: %[[#A:]] = insertvalue { i32, i32 } undef, i32 %[[#OPR]], 0
// LLVM-NEXT: %{{.+}} = insertvalue { i32, i32 } %[[#A]], i32 %[[#RESI]], 1

// CHECK: }

void builtin_conj() {
cd1 = __builtin_conj(cd1);
}

// CLANG: @builtin_conj
// CPPLANG: @_Z12builtin_conjv

// CIRGEN: %{{.+}} = cir.unary(not, %{{.+}}) : !cir.complex<!cir.double>, !cir.complex<!cir.double>

// CIR: %[[#OPR:]] = cir.complex.real %{{.+}} : !cir.complex<!cir.double> -> !cir.double
// CIR-NEXT: %[[#OPI:]] = cir.complex.imag %{{.+}} : !cir.complex<!cir.double> -> !cir.double
// CIR-NEXT: %[[#RESI:]] = cir.unary(minus, %[[#OPI]]) : !cir.double, !cir.double
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#OPR]], %[[#RESI]] : !cir.double -> !cir.complex<!cir.double>

// LLVM: %[[#OPR:]] = extractvalue { double, double } %{{.+}}, 0
// LLVM-NEXT: %[[#OPI:]] = extractvalue { double, double } %{{.+}}, 1
// LLVM-NEXT: %[[#RESI:]] = fneg double %[[#OPI]]
// LLVM-NEXT: %[[#A:]] = insertvalue { double, double } undef, double %[[#OPR]], 0
// LLVM-NEXT: %{{.+}} = insertvalue { double, double } %[[#A]], double %[[#RESI]], 1

// CHECK: }
Loading