Skip to content

Commit 16ae9fe

Browse files
Merge pull request #2154 from j2kun:test-kernel-impls
PiperOrigin-RevId: 800497389
2 parents e766174 + f4063d5 commit 16ae9fe

22 files changed

+931
-446
lines changed

lib/Dialect/TensorExt/Transforms/BUILD

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,10 @@ cc_library(
132132
deps = [
133133
":pass_inc_gen",
134134
"@heir//lib/Dialect/TensorExt/IR:Dialect",
135+
"@heir//lib/Kernel:ArithmeticDag",
136+
"@heir//lib/Kernel:IRMaterializingVisitor",
137+
"@heir//lib/Kernel:KernelImplementation",
135138
"@llvm-project//llvm:Support",
136-
"@llvm-project//mlir:Analysis",
137139
"@llvm-project//mlir:ArithDialect",
138140
"@llvm-project//mlir:IR",
139141
"@llvm-project//mlir:Pass",

lib/Dialect/TensorExt/Transforms/ImplementRotateAndReduce.cpp

Lines changed: 38 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -2,55 +2,55 @@
22

33
#include <cmath>
44
#include <cstdint>
5+
#include <memory>
6+
#include <optional>
57

68
#include "lib/Dialect/TensorExt/IR/TensorExtOps.h"
7-
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
8-
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
9-
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
10-
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
11-
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
12-
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project
13-
#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project
14-
#include "mlir/include/mlir/IR/OperationSupport.h" // from @llvm-project
15-
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
16-
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
17-
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
18-
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
9+
#include "lib/Kernel/ArithmeticDag.h"
10+
#include "lib/Kernel/IRMaterializingVisitor.h"
11+
#include "lib/Kernel/KernelImplementation.h"
12+
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
13+
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
14+
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
15+
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project
16+
#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project
17+
#include "mlir/include/mlir/IR/OperationSupport.h" // from @llvm-project
18+
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
19+
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
20+
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
21+
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
1922

2023
#define DEBUG_TYPE "implement-rotate-and-reduce"
2124

2225
namespace mlir {
2326
namespace heir {
2427
namespace tensor_ext {
2528

29+
using ::mlir::heir::kernel::ArithmeticDagNode;
30+
using ::mlir::heir::kernel::implementRotateAndReduce;
31+
using ::mlir::heir::kernel::IRMaterializingVisitor;
32+
using ::mlir::heir::kernel::SSAValue;
33+
2634
#define GEN_PASS_DEF_IMPLEMENTROTATEANDREDUCE
2735
#include "lib/Dialect/TensorExt/Transforms/Passes.h.inc"
2836

29-
// TODO(#2136): Add a better way to test the correctness of this kernel.
3037
LogicalResult convertRotateAndReduceOp(RotateAndReduceOp op) {
3138
LLVM_DEBUG(llvm::dbgs() << "Converting tensor_ext.rotate_and_reduce op: "
3239
<< op << "\n");
33-
if (!op.getPlaintexts()) {
34-
// TODO(#2122): Implement the case where we accumulate the ciphertext slot
35-
// values.
36-
return op->emitOpError() << "rotate and reduce not implemented yet for "
37-
"ciphertext value accumulation";
38-
}
39-
40-
IRRewriter rewriter(op.getContext());
4140
TypedValue<RankedTensorType> input = op.getTensor();
42-
TypedValue<RankedTensorType> plaintexts = op.getPlaintexts();
4341
unsigned steps = op.getSteps().getZExtValue();
4442
unsigned period = op.getPeriod().getZExtValue();
43+
std::shared_ptr<ArithmeticDagNode<SSAValue>> implementedKernel;
44+
SSAValue vectorLeaf(input);
4545

46-
StringRef mulOpName = isa<IntegerType>(input.getType().getElementType())
47-
? "arith.muli"
48-
: "arith.mulf";
49-
StringRef addOpName = isa<IntegerType>(input.getType().getElementType())
50-
? "arith.addi"
51-
: "arith.addf";
46+
if (!op.getPlaintexts()) {
47+
implementedKernel = implementRotateAndReduce(
48+
vectorLeaf, std::optional<SSAValue>(), period, steps);
49+
}
50+
51+
TypedValue<RankedTensorType> plaintexts = op.getPlaintexts();
5252

53-
// Use a value of sqrt(n) as the baby step / giant step size.
53+
// Validate divisibility of step size
5454
auto babySteps = static_cast<int64_t>(std::floor(std::sqrt(steps)));
5555
unsigned giantSteps = steps / babySteps;
5656
if (giantSteps * babySteps != steps) {
@@ -64,78 +64,16 @@ LogicalResult convertRotateAndReduceOp(RotateAndReduceOp op) {
6464
<< steps << " with babySteps= " << babySteps
6565
<< " and giantSteps= " << giantSteps << "\n");
6666

67-
// Compute sqrt(n) ciphertext rotations of the input as baby-steps.
68-
rewriter.setInsertionPointAfter(op);
69-
SmallVector<Value> babyStepVals;
70-
babyStepVals.push_back(input);
71-
for (int64_t i = 1; i < babySteps; ++i) {
72-
babyStepVals.push_back(rewriter
73-
.create<tensor_ext::RotateOp>(
74-
op->getLoc(), input,
75-
rewriter.create<arith::ConstantIndexOp>(
76-
op->getLoc(), period * i))
77-
.getResult());
78-
}
79-
80-
unsigned plaintextSize = plaintexts.getType().getRank();
81-
SmallVector<OpFoldResult> offsets(plaintextSize, rewriter.getIndexAttr(0));
82-
SmallVector<OpFoldResult> sliceSizes;
83-
sliceSizes.reserve(plaintextSize);
84-
sliceSizes.push_back(rewriter.getIndexAttr(1));
85-
for (int64_t i = 1; i < plaintextSize; ++i) {
86-
sliceSizes.push_back(
87-
rewriter.getIndexAttr(plaintexts.getType().getDimSize(i)));
88-
}
89-
SmallVector<OpFoldResult> unitStrides(plaintextSize,
90-
rewriter.getIndexAttr(1));
91-
92-
// Compute the inner baby step sums.
93-
Value result;
94-
for (unsigned k = 0; k < giantSteps; ++k) {
95-
Value innerSum;
96-
auto rotationIndex = rewriter.create<arith::ConstantIndexOp>(
97-
op->getLoc(), -babySteps * k * period);
98-
for (unsigned j = 0; j < babySteps; ++j) {
99-
offsets[0] = rewriter.getIndexAttr(j + k * babySteps * period);
100-
Value rotatedPlaintext = rewriter.create<tensor_ext::RotateOp>(
101-
op->getLoc(),
102-
rewriter.create<tensor::ExtractSliceOp>(op->getLoc(), input.getType(),
103-
plaintexts, offsets,
104-
sliceSizes, unitStrides),
105-
rotationIndex);
106-
Value multiplied =
107-
rewriter
108-
.create(OperationState(op->getLoc(), mulOpName,
109-
{rotatedPlaintext, babyStepVals[j]},
110-
{rotatedPlaintext.getType()}))
111-
->getResults()[0];
112-
if (!innerSum) {
113-
innerSum = multiplied;
114-
} else {
115-
innerSum = rewriter
116-
.create(OperationState(op->getLoc(), addOpName,
117-
{innerSum, multiplied},
118-
{innerSum.getType()}))
119-
->getResults()[0];
120-
}
121-
}
122-
123-
auto rotatedSum = rewriter.create<tensor_ext::RotateOp>(
124-
op->getLoc(), innerSum,
125-
rewriter.create<arith::ConstantIndexOp>(op->getLoc(),
126-
period * k * babySteps));
127-
if (!result) {
128-
result = rotatedSum;
129-
} else {
130-
result =
131-
rewriter
132-
.create(OperationState(op->getLoc(), addOpName,
133-
{result, rotatedSum}, {result.getType()}))
134-
->getResults()[0];
135-
}
136-
}
67+
auto plaintextsLeaf = std::optional<SSAValue>(plaintexts);
68+
implementedKernel =
69+
implementRotateAndReduce(vectorLeaf, plaintextsLeaf, period, steps);
13770

138-
rewriter.replaceOp(op, result);
71+
IRRewriter rewriter(op.getContext());
72+
rewriter.setInsertionPointAfter(op);
73+
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
74+
IRMaterializingVisitor visitor(b, input.getType());
75+
Value finalOutput = implementedKernel->visit(visitor);
76+
rewriter.replaceOp(op, finalOutput);
13977
return success();
14078
}
14179

lib/Utils/ArithmeticDag.h renamed to lib/Kernel/ArithmeticDag.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33

44
#include <cassert>
55
#include <cstddef>
6+
#include <cstdint>
67
#include <memory>
78
#include <unordered_map>
89
#include <utility>
910
#include <variant>
1011

1112
namespace mlir {
1213
namespace heir {
14+
namespace kernel {
1315

1416
// This file contains a generic DAG structure that can be used for representing
1517
// arithmetic DAGs with leaf nodes of various types.
@@ -53,7 +55,7 @@ struct PowerNode {
5355
template <typename T>
5456
struct LeftRotateNode {
5557
std::shared_ptr<ArithmeticDagNode<T>> operand;
56-
size_t shift;
58+
int64_t shift;
5759
};
5860

5961
template <typename T>
@@ -141,7 +143,7 @@ struct ArithmeticDagNode {
141143
}
142144

143145
static std::shared_ptr<ArithmeticDagNode<T>> leftRotate(
144-
std::shared_ptr<ArithmeticDagNode<T>> tensor, size_t shift) {
146+
std::shared_ptr<ArithmeticDagNode<T>> tensor, int64_t shift) {
145147
assert(tensor && "invalid tensor for leftRotate");
146148
auto node =
147149
std::shared_ptr<ArithmeticDagNode<T>>(new ArithmeticDagNode<T>());
@@ -213,40 +215,49 @@ class CachingVisitor {
213215

214216
virtual ResultType operator()(const ConstantNode& node) {
215217
assert(false && "Visit logic for ConstantNode is not implemented.");
218+
return ResultType();
216219
}
217220

218221
virtual ResultType operator()(const LeafNode<T>& node) {
219222
assert(false && "Visit logic for LeafNode is not implemented.");
223+
return ResultType();
220224
}
221225

222226
virtual ResultType operator()(const AddNode<T>& node) {
223227
assert(false && "Visit logic for AddNode is not implemented.");
228+
return ResultType();
224229
}
225230

226231
virtual ResultType operator()(const SubtractNode<T>& node) {
227232
assert(false && "Visit logic for SubtractNode is not implemented.");
233+
return ResultType();
228234
}
229235

230236
virtual ResultType operator()(const MultiplyNode<T>& node) {
231237
assert(false && "Visit logic for MultiplyNode is not implemented.");
238+
return ResultType();
232239
}
233240

234241
virtual ResultType operator()(const PowerNode<T>& node) {
235242
assert(false && "Visit logic for PowerNode is not implemented.");
243+
return ResultType();
236244
}
237245

238246
virtual ResultType operator()(const LeftRotateNode<T>& node) {
239247
assert(false && "Visit logic for LeftRotateNode is not implemented.");
248+
return ResultType();
240249
}
241250

242251
virtual ResultType operator()(const ExtractNode<T>& node) {
243252
assert(false && "Visit logic for ExtractNode is not implemented.");
253+
return ResultType();
244254
}
245255

246256
private:
247257
std::unordered_map<const ArithmeticDagNode<T>*, ResultType> cache;
248258
};
249259

260+
} // namespace kernel
250261
} // namespace heir
251262
} // namespace mlir
252263

lib/Utils/ArithmeticDagTest.cpp renamed to lib/Kernel/ArithmeticDagTest.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66
#include <string>
77

88
#include "gtest/gtest.h" // from @googletest
9-
#include "lib/Utils/ArithmeticDag.h"
9+
#include "lib/Kernel/ArithmeticDag.h"
1010

1111
namespace mlir {
1212
namespace heir {
13+
namespace kernel {
1314
namespace {
1415

1516
using StringLeavedDag = ArithmeticDagNode<std::string>;
@@ -154,5 +155,6 @@ TEST(ArithmeticDagTest, TestEvaluationVisitorSubstract) {
154155
}
155156

156157
} // namespace
158+
} // namespace kernel
157159
} // namespace heir
158160
} // namespace mlir

lib/Kernel/BUILD

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,38 @@ package(
66
default_visibility = ["//visibility:public"],
77
)
88

9+
cc_library(
10+
name = "ArithmeticDag",
11+
srcs = ["ArithmeticDag.h"],
12+
hdrs = ["ArithmeticDag.h"],
13+
)
14+
15+
cc_test(
16+
name = "ArithmeticDagTest",
17+
srcs = ["ArithmeticDagTest.cpp"],
18+
deps = [
19+
":ArithmeticDag",
20+
"@googletest//:gtest_main",
21+
],
22+
)
23+
24+
cc_library(
25+
name = "IRMaterializingVisitor",
26+
srcs = ["IRMaterializingVisitor.cpp"],
27+
hdrs = ["IRMaterializingVisitor.h"],
28+
deps = [
29+
":ArithmeticDag",
30+
":KernelImplementation",
31+
"@heir//lib/Dialect/TensorExt/IR:Dialect",
32+
"@heir//lib/Utils:MathUtils",
33+
"@llvm-project//llvm:Support",
34+
"@llvm-project//mlir:ArithDialect",
35+
"@llvm-project//mlir:IR",
36+
"@llvm-project//mlir:Support",
37+
"@llvm-project//mlir:TensorDialect",
38+
],
39+
)
40+
941
cc_library(
1042
name = "Kernel",
1143
srcs = ["Kernel.cpp"],
@@ -22,23 +54,37 @@ cc_library(
2254

2355
cc_library(
2456
name = "KernelImplementation",
25-
srcs = ["KernelImplementation.cpp"],
2657
hdrs = ["KernelImplementation.h"],
2758
deps = [
59+
":ArithmeticDag",
2860
":Kernel",
29-
"@heir//lib/Utils:ArithmeticDag",
3061
"@llvm-project//mlir:IR",
62+
"@llvm-project//mlir:Support",
63+
"@llvm-project//mlir:TensorDialect",
64+
],
65+
)
66+
67+
cc_library(
68+
name = "TestingUtils",
69+
srcs = ["TestingUtils.cpp"],
70+
hdrs = ["TestingUtils.h"],
71+
deps = [
72+
":ArithmeticDag",
73+
":KernelImplementation",
3174
],
3275
)
3376

3477
cc_test(
3578
name = "KernelImplementationTest",
36-
srcs = ["KernelImplementationTest.cpp"],
79+
srcs = [
80+
"KernelImplementationTest.cpp",
81+
"RotateAndReduceImplTest.cpp",
82+
],
3783
deps = [
84+
":ArithmeticDag",
3885
":Kernel",
3986
":KernelImplementation",
87+
":TestingUtils",
4088
"@googletest//:gtest_main",
41-
"@heir//lib/Utils:ArithmeticDag",
42-
"@llvm-project//mlir:IR",
4389
],
4490
)

0 commit comments

Comments
 (0)