Skip to content

Commit f257f95

Browse files
committed
Replace rotate-and-reduce with new lowering
Adds support for mul reductions in rotate-and-reduce kernel impl
1 parent d168e37 commit f257f95

File tree

10 files changed

+142
-71
lines changed

10 files changed

+142
-71
lines changed

lib/Analysis/RotationAnalysis/RotationAnalysis.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
1818
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
1919

20+
#define DEBUG_TYPE "rotation-analysis"
21+
2022
namespace mlir {
2123
namespace heir {
2224
namespace rotation_analysis {

lib/Analysis/RotationAnalysis/RotationAnalysis.h

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
#include <vector>
1111

1212
#include "llvm/include/llvm/Support/Casting.h" // from @llvm-project
13-
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
1413
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
1514
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
1615
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project
@@ -19,8 +18,6 @@
1918
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
2019
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
2120

22-
#define DEBUG_TYPE "rotation-analysis"
23-
2421
namespace mlir {
2522
namespace heir {
2623
namespace rotation_analysis {
@@ -92,8 +89,6 @@ class PartialReduction {
9289
// first element.
9390
reduction.addRotation(0);
9491

95-
LLVM_DEBUG(llvm::dbgs()
96-
<< "Initializing at " << tensor << " with rotations [0]\n");
9792
return reduction;
9893
}
9994

@@ -107,11 +102,6 @@ class PartialReduction {
107102
"Internal state of RotationAnalysis is broken; tensor having saved "
108103
"value should be impossible");
109104

110-
LLVM_DEBUG({
111-
llvm::dbgs() << "Rotating\n\t";
112-
lhs.print(llvm::dbgs());
113-
llvm::dbgs() << " by " << shift;
114-
});
115105
PartialReduction shifted;
116106
shifted.tensor = lhs.tensor;
117107
shifted.opName = lhs.opName;
@@ -124,11 +114,6 @@ class PartialReduction {
124114
for (auto index : lhs.accessedIndices) {
125115
shifted.addRotation((index + shift) % size);
126116
}
127-
LLVM_DEBUG({
128-
llvm::dbgs() << " to\n\t";
129-
shifted.print(llvm::dbgs());
130-
llvm::dbgs() << "\n";
131-
});
132117
return shifted;
133118
}
134119

@@ -202,15 +187,6 @@ class PartialReduction {
202187
for (auto value : rhs.savedValues) {
203188
merged.savedValues.push_back(value);
204189
}
205-
LLVM_DEBUG({
206-
llvm::dbgs() << "Joining\n\t";
207-
lhs.print(llvm::dbgs());
208-
llvm::dbgs() << " and\n\t";
209-
rhs.print(llvm::dbgs());
210-
llvm::dbgs() << " to get\n\t";
211-
merged.print(llvm::dbgs());
212-
llvm::dbgs() << "\n";
213-
});
214190
return merged;
215191
}
216192

@@ -254,15 +230,6 @@ class PartialReduction {
254230
}
255231
merged.savedValues = lhs.savedValues;
256232
merged.savedValues.push_back(rhs);
257-
LLVM_DEBUG({
258-
llvm::dbgs() << "Saving\n\t";
259-
rhs.print(llvm::dbgs());
260-
llvm::dbgs() << " inside\n\t";
261-
lhs.print(llvm::dbgs());
262-
llvm::dbgs() << " to get\n\t";
263-
merged.print(llvm::dbgs());
264-
llvm::dbgs() << "\n";
265-
});
266233
return merged;
267234
}
268235

lib/Dialect/TensorExt/IR/TensorExtOps.td

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def TensorExt_RotateAndReduceOp : TensorExt_Op<"rotate_and_reduce",[Pure, AllTyp
139139
This op reduces products of a plaintext with a periodically rotated
140140
tensor.
141141

142-
In generality, the reduction performs is
142+
In almost full generality, the reduction performed is
143143

144144
\[
145145
\sum_{i \in [0, n]} p(P, T*i) \cdot rotate(v, T*i)
@@ -167,6 +167,11 @@ def TensorExt_RotateAndReduceOp : TensorExt_Op<"rotate_and_reduce",[Pure, AllTyp
167167
`n = |v|` so that the reduction is simply a sum of all rotation of the
168168
ciphertext.
169169

170+
If `reduceOp` is set to an MLIR operation name (e.g., `arith.mulf`), then
171+
the reduction operation is modified to use that operation instead of a sum.
172+
The chosen op must be one of `arith.muli`, `arith.mulf`, `arith.addi`,
173+
or `arith.addf`.
174+
170175
Efficient lowerings of this operation can use the Baby-Step / Giant-Step
171176
approach from [Faster Homomorphic Linear Transformations in
172177
HElib](https://eprint.iacr.org/2018/244.pdf) to reduce the number of
@@ -177,10 +182,34 @@ def TensorExt_RotateAndReduceOp : TensorExt_Op<"rotate_and_reduce",[Pure, AllTyp
177182
ins AnyRankedTensor:$tensor,
178183
Optional<AnyRankedTensor>:$plaintexts,
179184
IndexAttr:$period,
180-
IndexAttr:$steps
185+
IndexAttr:$steps,
186+
OptionalAttr<Builtin_StringAttr>:$reduceOp
181187
);
182188
let results = (outs AnyRankedTensor:$output);
183189
let hasVerifier = 1;
190+
191+
let builders = [
192+
// Default builder for case of empty plaintexts
193+
OpBuilder<(ins
194+
"Value":$tensor, "int64_t":$period, "int64_t":$steps,
195+
"::llvm::StringRef":$reduceOp), [{
196+
return build(
197+
$_builder,
198+
$_state,
199+
tensor.getType(),
200+
ValueRange{tensor},
201+
{
202+
$_builder.getNamedAttr(
203+
"period", $_builder.getIndexAttr(period)),
204+
$_builder.getNamedAttr(
205+
"steps", $_builder.getIndexAttr(steps)),
206+
$_builder.getNamedAttr(
207+
"reduceOp", $_builder.getStringAttr(reduceOp))
208+
}
209+
);
210+
}]>
211+
];
212+
184213
// TODO(#2134): Add canonicalization patterns
185214
}
186215

lib/Dialect/TensorExt/Transforms/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ cc_library(
7373
"RotateAndReduce.h",
7474
],
7575
deps = [
76+
":ImplementRotateAndReduce",
7677
":pass_inc_gen",
7778
"@heir//lib/Analysis/RotationAnalysis",
7879
"@heir//lib/Dialect/TensorExt/IR:Dialect",

lib/Dialect/TensorExt/Transforms/ImplementRotateAndReduce.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,17 +42,18 @@ LogicalResult convertRotateAndReduceOp(RotateAndReduceOp op) {
4242
unsigned period = op.getPeriod().getZExtValue();
4343
std::shared_ptr<ArithmeticDagNode<SSAValue>> implementedKernel;
4444
SSAValue vectorLeaf(input);
45+
std::optional<SSAValue> plaintextsLeaf = std::nullopt;
4546

46-
if (!op.getPlaintexts()) {
47-
implementedKernel = implementRotateAndReduce(
48-
vectorLeaf, std::optional<SSAValue>(), period, steps);
47+
if (op.getPlaintexts()) {
48+
plaintextsLeaf = std::optional<SSAValue>(op.getPlaintexts());
4949
}
5050

51-
TypedValue<RankedTensorType> plaintexts = op.getPlaintexts();
52-
auto plaintextsLeaf = std::optional<SSAValue>(plaintexts);
53-
implementedKernel =
54-
implementRotateAndReduce(vectorLeaf, plaintextsLeaf, period, steps);
55-
51+
std::string reduceOp = "arith.addi";
52+
if (op.getReduceOp().has_value() && *op.getReduceOp() != nullptr) {
53+
reduceOp = op.getReduceOp()->getValue().str();
54+
}
55+
implementedKernel = implementRotateAndReduce(vectorLeaf, plaintextsLeaf,
56+
period, steps, reduceOp);
5657
IRRewriter rewriter(op.getContext());
5758
rewriter.setInsertionPointAfter(op);
5859
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

lib/Dialect/TensorExt/Transforms/ImplementRotateAndReduce.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ namespace tensor_ext {
1313
#define GEN_PASS_DECL_IMPLEMENTROTATEANDREDUCE
1414
#include "lib/Dialect/TensorExt/Transforms/Passes.h.inc"
1515

16+
LogicalResult convertRotateAndReduceOp(RotateAndReduceOp op);
17+
1618
} // namespace tensor_ext
1719
} // namespace heir
1820
} // namespace mlir

lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
#include "lib/Dialect/TensorExt/Transforms/RotateAndReduce.h"
22

3-
#include <cstdint>
4-
53
#include "lib/Analysis/RotationAnalysis/RotationAnalysis.h"
64
#include "lib/Dialect/TensorExt/IR/TensorExtOps.h"
5+
#include "lib/Dialect/TensorExt/Transforms/ImplementRotateAndReduce.h"
76
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
87
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
98
#include "mlir/include/mlir/Analysis/DataFlow/Utils.h" // from @llvm-project
@@ -13,13 +12,12 @@
1312
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
1413
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
1514
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
16-
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
1715
#include "mlir/include/mlir/IR/Iterators.h" // from @llvm-project
1816
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
1917
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
2018
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
2119

22-
#define DEBUG_NAME "rotate-and-reduce"
20+
#define DEBUG_TYPE "rotate-and-reduce"
2321

2422
namespace mlir {
2523
namespace heir {
@@ -28,13 +26,12 @@ namespace tensor_ext {
2826
#define GEN_PASS_DEF_ROTATEANDREDUCE
2927
#include "lib/Dialect/TensorExt/Transforms/Passes.h.inc"
3028

31-
/// A pass that searches for a length N sequence of binary operations that
29+
/// A pass that searches for a length N sequence of add operations that
3230
/// reduces a length N vector to a single scalar, and replaces it with a
3331
/// logarithmic number of rotations and binary operations.
3432
struct RotateAndReduce : impl::RotateAndReduceBase<RotateAndReduce> {
3533
using RotateAndReduceBase::RotateAndReduceBase;
3634

37-
// TODO(#2123): Rewrite this to use the tensor_ext.rotate_and_reduce op.
3835
template <typename ArithOp>
3936
void tryReplaceRotations(ArithOp op,
4037
const rotation_analysis::PartialReduction& reduction,
@@ -43,17 +40,16 @@ struct RotateAndReduce : impl::RotateAndReduceBase<RotateAndReduce> {
4340
<< "Trying to replace rotations ending in " << *op << "\n");
4441
auto b = ImplicitLocOpBuilder(op->getLoc(), op);
4542
auto tensor = reduction.getTensor();
46-
Operation* finalOp = nullptr;
4743
auto tensorShape =
4844
mlir::cast<RankedTensorType>(tensor.getType()).getShape();
49-
for (int64_t shiftSize = tensorShape[0] / 2; shiftSize > 0;
50-
shiftSize /= 2) {
51-
auto rotatedTensor = tensor_ext::RotateOp::create(
52-
b, tensor, arith::ConstantOp::create(b, b.getIndexAttr(shiftSize)));
53-
auto addOp = ArithOp::create(b, tensor, rotatedTensor);
54-
finalOp = addOp;
55-
tensor = addOp->getResult(0);
56-
}
45+
46+
// Get the operation name for the reduce_op attribute
47+
auto rotateAndReduceOp = tensor_ext::RotateAndReduceOp::create(
48+
b, tensor,
49+
/*period=*/1,
50+
/*steps=*/tensorShape[0],
51+
/*reduceOp=*/op->getName().getStringRef());
52+
Operation* finalOp = rotateAndReduceOp;
5753

5854
[[maybe_unused]] auto* parentOp = op->getParentOp();
5955
if (extraction) {
@@ -69,6 +65,12 @@ struct RotateAndReduce : impl::RotateAndReduceBase<RotateAndReduce> {
6965
}
7066
if (finalOp) op->replaceAllUsesWith(finalOp);
7167
LLVM_DEBUG(llvm::dbgs() << "Post-replacement: " << *parentOp << "\n");
68+
69+
// Convert the rotate_and_reduce op to its implementation immediately
70+
if (failed(convertRotateAndReduceOp(rotateAndReduceOp))) {
71+
LLVM_DEBUG(llvm::dbgs() << "Failed to convert rotate_and_reduce op\n");
72+
return;
73+
}
7274
}
7375

7476
void runOnOperation() override {
@@ -98,20 +100,22 @@ struct RotateAndReduce : impl::RotateAndReduceBase<RotateAndReduce> {
98100

99101
for (const auto& reduction :
100102
rotationAnalysis.getRootedReductionsAt(result)) {
101-
if (reduction.isComplete()) {
103+
if (reduction.isComplete() &&
104+
cast<RankedTensorType>(reduction.getTensor().getType())
105+
.getNumElements() > 1) {
102106
llvm::TypeSwitch<Operation&>(*op)
103107
.Case<arith::AddIOp>([&](auto arithOp) {
104108
tryReplaceRotations<arith::AddIOp>(arithOp, reduction,
105109
extraction);
106110
})
107-
.Case<arith::MulIOp>([&](auto arithOp) {
108-
tryReplaceRotations<arith::MulIOp>(arithOp, reduction,
109-
extraction);
110-
})
111111
.Case<arith::AddFOp>([&](auto arithOp) {
112112
tryReplaceRotations<arith::AddFOp>(arithOp, reduction,
113113
extraction);
114114
})
115+
.Case<arith::MulIOp>([&](auto arithOp) {
116+
tryReplaceRotations<arith::MulIOp>(arithOp, reduction,
117+
extraction);
118+
})
115119
.Case<arith::MulFOp>([&](auto arithOp) {
116120
tryReplaceRotations<arith::MulFOp>(arithOp, reduction,
117121
extraction);

lib/Kernel/KernelImplementation.h

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -132,15 +132,30 @@ template <typename T>
132132
std::enable_if_t<std::is_base_of<AbstractValue, T>::value,
133133
std::shared_ptr<ArithmeticDagNode<T>>>
134134
implementRotateAndReduce(const T& vector, std::optional<T> plaintexts,
135-
int64_t period, int64_t steps) {
135+
int64_t period, int64_t steps,
136+
const std::string& reduceOp = "arith.addi") {
136137
using NodeTy = ArithmeticDagNode<T>;
137138
auto vectorDag = NodeTy::leaf(vector);
138139

140+
auto performReduction = [&](std::shared_ptr<NodeTy> left,
141+
std::shared_ptr<NodeTy> right) {
142+
if (reduceOp == "arith.addi" || reduceOp == "arith.addf") {
143+
return NodeTy::add(left, right);
144+
}
145+
146+
if (reduceOp == "arith.muli" || reduceOp == "arith.mulf") {
147+
return NodeTy::mul(left, right);
148+
}
149+
150+
// Default to add for unknown operations
151+
return NodeTy::add(left, right);
152+
};
153+
139154
if (!plaintexts.has_value()) {
140155
for (int64_t shiftSize = steps / 2; shiftSize > 0; shiftSize /= 2) {
141156
auto rotated = NodeTy::leftRotate(vectorDag, shiftSize * period);
142-
auto added = NodeTy::add(vectorDag, rotated);
143-
vectorDag = added;
157+
auto reduced = performReduction(vectorDag, rotated);
158+
vectorDag = reduced;
144159
}
145160
return vectorDag;
146161
}
@@ -191,12 +206,13 @@ implementRotateAndReduce(const T& vector, std::optional<T> plaintexts,
191206
auto rotatedPlaintext =
192207
NodeTy::leftRotate(plaintext, plaintextRotationAmount);
193208
auto multiplied = NodeTy::mul(rotatedPlaintext, babyStepVals[i]);
194-
innerSum =
195-
innerSum == nullptr ? multiplied : NodeTy::add(innerSum, multiplied);
209+
innerSum = innerSum == nullptr ? multiplied
210+
: performReduction(innerSum, multiplied);
196211
}
197212

198213
auto rotatedSum = NodeTy::leftRotate(innerSum, period * j * giantStepSize);
199-
result = result == nullptr ? rotatedSum : NodeTy::add(result, rotatedSum);
214+
result =
215+
result == nullptr ? rotatedSum : performReduction(result, rotatedSum);
200216
}
201217

202218
return result;

lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -944,7 +944,8 @@ struct ConvertLinalgMatvecNewLayout
944944
auto rotateAndReduceOp = rewriter.create<tensor_ext::RotateAndReduceOp>(
945945
op.getLoc(), packedVectorType, packedVector, packedMatrix,
946946
/*period=*/rewriter.getIndexAttr(1),
947-
/*steps=*/rewriter.getIndexAttr(numRotations));
947+
/*steps=*/rewriter.getIndexAttr(numRotations),
948+
/*reduce_op=*/rewriter.getStringAttr("arith.addf"));
948949
rotateAndReduceOp->setAttr(kLayoutAttrName, layoutAttr);
949950
setMaterializedAttr(rotateAndReduceOp);
950951

0 commit comments

Comments
 (0)