2
2
3
3
#include < cmath>
4
4
#include < cstdint>
5
+ #include < memory>
6
+ #include < optional>
5
7
6
8
#include " lib/Dialect/TensorExt/IR/TensorExtOps.h"
7
- #include " llvm/include/llvm/Support/Debug.h" // from @llvm-project
8
- #include " mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
9
- #include " mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
10
- #include " mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
11
- #include " mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
12
- #include " mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project
13
- #include " mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project
14
- #include " mlir/include/mlir/IR/OperationSupport.h" // from @llvm-project
15
- #include " mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
16
- #include " mlir/include/mlir/IR/Value.h" // from @llvm-project
17
- #include " mlir/include/mlir/Support/LLVM.h" // from @llvm-project
18
- #include " mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
9
+ #include " lib/Kernel/ArithmeticDag.h"
10
+ #include " lib/Kernel/IRMaterializingVisitor.h"
11
+ #include " lib/Kernel/KernelImplementation.h"
12
+ #include " llvm/include/llvm/Support/Debug.h" // from @llvm-project
13
+ #include " mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
14
+ #include " mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
15
+ #include " mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project
16
+ #include " mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project
17
+ #include " mlir/include/mlir/IR/OperationSupport.h" // from @llvm-project
18
+ #include " mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
19
+ #include " mlir/include/mlir/IR/Value.h" // from @llvm-project
20
+ #include " mlir/include/mlir/Support/LLVM.h" // from @llvm-project
21
+ #include " mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
19
22
20
23
#define DEBUG_TYPE " implement-rotate-and-reduce"
21
24
22
25
namespace mlir {
23
26
namespace heir {
24
27
namespace tensor_ext {
25
28
29
+ using ::mlir::heir::kernel::ArithmeticDagNode;
30
+ using ::mlir::heir::kernel::implementRotateAndReduce;
31
+ using ::mlir::heir::kernel::IRMaterializingVisitor;
32
+ using ::mlir::heir::kernel::SSAValue;
33
+
26
34
#define GEN_PASS_DEF_IMPLEMENTROTATEANDREDUCE
27
35
#include " lib/Dialect/TensorExt/Transforms/Passes.h.inc"
28
36
29
- // TODO(#2136): Add a better way to test the correctness of this kernel.
30
37
LogicalResult convertRotateAndReduceOp (RotateAndReduceOp op) {
31
38
LLVM_DEBUG (llvm::dbgs () << " Converting tensor_ext.rotate_and_reduce op: "
32
39
<< op << " \n " );
33
- if (!op.getPlaintexts ()) {
34
- // TODO(#2122): Implement the case where we accumulate the ciphertext slot
35
- // values.
36
- return op->emitOpError () << " rotate and reduce not implemented yet for "
37
- " ciphertext value accumulation" ;
38
- }
39
-
40
- IRRewriter rewriter (op.getContext ());
41
40
TypedValue<RankedTensorType> input = op.getTensor ();
42
- TypedValue<RankedTensorType> plaintexts = op.getPlaintexts ();
43
41
unsigned steps = op.getSteps ().getZExtValue ();
44
42
unsigned period = op.getPeriod ().getZExtValue ();
43
+ std::shared_ptr<ArithmeticDagNode<SSAValue>> implementedKernel;
44
+ SSAValue vectorLeaf (input);
45
45
46
- StringRef mulOpName = isa<IntegerType>(input. getType (). getElementType ())
47
- ? " arith.muli "
48
- : " arith.mulf " ;
49
- StringRef addOpName = isa<IntegerType>(input. getType (). getElementType ())
50
- ? " arith.addi "
51
- : " arith.addf " ;
46
+ if (!op. getPlaintexts ()) {
47
+ implementedKernel = implementRotateAndReduce (
48
+ vectorLeaf, std::optional<SSAValue>(), period, steps) ;
49
+ }
50
+
51
+ TypedValue<RankedTensorType> plaintexts = op. getPlaintexts () ;
52
52
53
- // Use a value of sqrt(n) as the baby step / giant step size.
53
+ // Check that steps is an exact multiple of babySteps = floor(sqrt(steps)).
54
54
auto babySteps = static_cast <int64_t >(std::floor (std::sqrt (steps)));
55
55
unsigned giantSteps = steps / babySteps;
56
56
if (giantSteps * babySteps != steps) {
@@ -64,78 +64,16 @@ LogicalResult convertRotateAndReduceOp(RotateAndReduceOp op) {
64
64
<< steps << " with babySteps= " << babySteps
65
65
<< " and giantSteps= " << giantSteps << " \n " );
66
66
67
- // Compute sqrt(n) ciphertext rotations of the input as baby-steps.
68
- rewriter.setInsertionPointAfter (op);
69
- SmallVector<Value> babyStepVals;
70
- babyStepVals.push_back (input);
71
- for (int64_t i = 1 ; i < babySteps; ++i) {
72
- babyStepVals.push_back (rewriter
73
- .create <tensor_ext::RotateOp>(
74
- op->getLoc (), input,
75
- rewriter.create <arith::ConstantIndexOp>(
76
- op->getLoc (), period * i))
77
- .getResult ());
78
- }
79
-
80
- unsigned plaintextSize = plaintexts.getType ().getRank ();
81
- SmallVector<OpFoldResult> offsets (plaintextSize, rewriter.getIndexAttr (0 ));
82
- SmallVector<OpFoldResult> sliceSizes;
83
- sliceSizes.reserve (plaintextSize);
84
- sliceSizes.push_back (rewriter.getIndexAttr (1 ));
85
- for (int64_t i = 1 ; i < plaintextSize; ++i) {
86
- sliceSizes.push_back (
87
- rewriter.getIndexAttr (plaintexts.getType ().getDimSize (i)));
88
- }
89
- SmallVector<OpFoldResult> unitStrides (plaintextSize,
90
- rewriter.getIndexAttr (1 ));
91
-
92
- // Compute the inner baby step sums.
93
- Value result;
94
- for (unsigned k = 0 ; k < giantSteps; ++k) {
95
- Value innerSum;
96
- auto rotationIndex = rewriter.create <arith::ConstantIndexOp>(
97
- op->getLoc (), -babySteps * k * period);
98
- for (unsigned j = 0 ; j < babySteps; ++j) {
99
- offsets[0 ] = rewriter.getIndexAttr (j + k * babySteps * period);
100
- Value rotatedPlaintext = rewriter.create <tensor_ext::RotateOp>(
101
- op->getLoc (),
102
- rewriter.create <tensor::ExtractSliceOp>(op->getLoc (), input.getType (),
103
- plaintexts, offsets,
104
- sliceSizes, unitStrides),
105
- rotationIndex);
106
- Value multiplied =
107
- rewriter
108
- .create (OperationState (op->getLoc (), mulOpName,
109
- {rotatedPlaintext, babyStepVals[j]},
110
- {rotatedPlaintext.getType ()}))
111
- ->getResults ()[0 ];
112
- if (!innerSum) {
113
- innerSum = multiplied;
114
- } else {
115
- innerSum = rewriter
116
- .create (OperationState (op->getLoc (), addOpName,
117
- {innerSum, multiplied},
118
- {innerSum.getType ()}))
119
- ->getResults ()[0 ];
120
- }
121
- }
122
-
123
- auto rotatedSum = rewriter.create <tensor_ext::RotateOp>(
124
- op->getLoc (), innerSum,
125
- rewriter.create <arith::ConstantIndexOp>(op->getLoc (),
126
- period * k * babySteps));
127
- if (!result) {
128
- result = rotatedSum;
129
- } else {
130
- result =
131
- rewriter
132
- .create (OperationState (op->getLoc (), addOpName,
133
- {result, rotatedSum}, {result.getType ()}))
134
- ->getResults ()[0 ];
135
- }
136
- }
67
+ auto plaintextsLeaf = std::optional<SSAValue>(plaintexts);
68
+ implementedKernel =
69
+ implementRotateAndReduce (vectorLeaf, plaintextsLeaf, period, steps);
137
70
138
- rewriter.replaceOp (op, result);
71
+ IRRewriter rewriter (op.getContext ());
72
+ rewriter.setInsertionPointAfter (op);
73
+ ImplicitLocOpBuilder b (op.getLoc (), rewriter);
74
+ IRMaterializingVisitor visitor (b, input.getType ());
75
+ Value finalOutput = implementedKernel->visit (visitor);
76
+ rewriter.replaceOp (op, finalOutput);
139
77
return success ();
140
78
}
141
79
0 commit comments