38
38
#include " mlir/include/mlir/IR/Attributes.h" // from @llvm-project
39
39
#include " mlir/include/mlir/IR/Builders.h" // from @llvm-project
40
40
#include " mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
41
+ #include " mlir/include/mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project
41
42
#include " mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
42
43
#include " mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project
43
44
#include " mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
@@ -153,7 +154,7 @@ class SecretToCKKSTypeConverter
153
154
154
155
auto ciphertext = lwe::LWECiphertextType::get (
155
156
ctx,
156
- lwe::ApplicationDataAttr::get (ctx, type. getValueType () ,
157
+ lwe::ApplicationDataAttr::get (ctx, valueTy ,
157
158
lwe::NoOverflowAttr::get (ctx)),
158
159
lwe::PlaintextSpaceAttr::get (
159
160
ctx, plaintextRing,
@@ -164,25 +165,60 @@ class SecretToCKKSTypeConverter
164
165
lwe::KeyAttr::get (ctx, 0 ),
165
166
lwe::ModulusChainAttr::get (ctx, moduliChain, level));
166
167
167
- // Return a single ciphertext if inputs are packed into a single
168
- // ciphertext SIMD slot or the secret value type is a scalar.
169
- if (this ->packTensorInSlots_ || !isa<TensorType>(valueTy)) {
170
- return ciphertext;
168
+ // Return a single ciphertext if the input is a scalar.
169
+ if (!isa<TensorType>(valueTy)) return ciphertext;
170
+
171
+ // The input is a tensor type.
172
+ assert (dyn_cast<RankedTensorType>(valueTy) &&
173
+ " expected ranked tensor type" );
174
+ auto tensorTy = cast<RankedTensorType>(valueTy);
175
+ // If the input is packed into a ciphertext SIMD slots (i.e. it is a tensor
176
+ // of shape NxciphertextSize) then return Nxciphertext.
177
+ if (this ->packTensorInSlots_ ) {
178
+ Type underlyingTy;
179
+ if (tensorTy.getRank () == 1 ) {
180
+ underlyingTy = valueTy;
181
+ auto ciphertext = lwe::LWECiphertextType::get (
182
+ ctx,
183
+ lwe::ApplicationDataAttr::get (ctx, underlyingTy,
184
+ lwe::NoOverflowAttr::get (ctx)),
185
+ lwe::PlaintextSpaceAttr::get (
186
+ ctx, plaintextRing,
187
+ lwe::InverseCanonicalEncodingAttr::get (ctx, scale)),
188
+ lwe::CiphertextSpaceAttr::get (
189
+ ctx, getRlweRNSRingWithLevel (ring_, level),
190
+ lwe::LweEncryptionType::mix, dimension),
191
+ lwe::KeyAttr::get (ctx, 0 ),
192
+ lwe::ModulusChainAttr::get (ctx, moduliChain, level));
193
+ return ciphertext;
194
+ }
195
+ assert (tensorTy.getRank () == 2 && " expected rank 1 or 2 tensor" );
196
+ underlyingTy = RankedTensorType::get (tensorTy.getShape ().drop_front (),
197
+ tensorTy.getElementType ());
198
+ auto ciphertext = lwe::LWECiphertextType::get (
199
+ ctx,
200
+ lwe::ApplicationDataAttr::get (ctx, underlyingTy,
201
+ lwe::NoOverflowAttr::get (ctx)),
202
+ lwe::PlaintextSpaceAttr::get (
203
+ ctx, plaintextRing,
204
+ lwe::InverseCanonicalEncodingAttr::get (ctx, scale)),
205
+ lwe::CiphertextSpaceAttr::get (ctx,
206
+ getRlweRNSRingWithLevel (ring_, level),
207
+ lwe::LweEncryptionType::mix, dimension),
208
+ lwe::KeyAttr::get (ctx, 0 ),
209
+ lwe::ModulusChainAttr::get (ctx, moduliChain, level));
210
+ return RankedTensorType::get (tensorTy.getShape ().drop_back (), ciphertext);
171
211
}
172
212
// If the input IR does not contain aligned ciphertexts, we will not
173
213
// pack tensors into ciphertext SIMD slots, so tensors are converted
174
214
// into tensors of RLWE ciphertexts.
175
- assert (dyn_cast<RankedTensorType>(valueTy) &&
176
- " expected ranked tensor type" );
177
- auto scalarType = cast<RankedTensorType>(valueTy).getElementType ();
178
215
ciphertext = lwe::LWECiphertextType::get (
179
216
ctx,
180
- lwe::ApplicationDataAttr::get (ctx, scalarType ,
217
+ lwe::ApplicationDataAttr::get (ctx, getElementTypeOrSelf (valueTy) ,
181
218
lwe::NoOverflowAttr::get (ctx)),
182
219
ciphertext.getPlaintextSpace (), ciphertext.getCiphertextSpace (),
183
220
ciphertext.getKey (), ciphertext.getModulusChain ());
184
- return RankedTensorType::get (cast<RankedTensorType>(valueTy).getShape (),
185
- ciphertext);
221
+ return RankedTensorType::get (tensorTy.getShape (), ciphertext);
186
222
}
187
223
188
224
private:
@@ -265,6 +301,83 @@ class SecretGenericTensorInsertConversion
265
301
}
266
302
};
267
303
304
+ class SecretGenericTensorExpandConversion
305
+ : public SecretGenericOpConversion<tensor::ExpandShapeOp,
306
+ tensor::ExpandShapeOp> {
307
+ public:
308
+ using SecretGenericOpConversion<
309
+ tensor::ExpandShapeOp, tensor::ExpandShapeOp>::SecretGenericOpConversion;
310
+
311
+ FailureOr<Operation*> matchAndRewriteInner (
312
+ secret::GenericOp op, TypeRange outputTypes, ValueRange inputs,
313
+ ArrayRef<NamedAttribute> attributes,
314
+ ContextAwareConversionPatternRewriter& rewriter) const override {
315
+ // We expect this operation to occur when dropping unit dimensions in order
316
+ // to allow rotation ops to operate on 1-D tensors.
317
+ SliceVerificationResult res = isRankReducedType (
318
+ cast<ShapedType>(
319
+ cast<secret::SecretType>(op.getResultTypes ()[0 ]).getValueType ()),
320
+ cast<ShapedType>(
321
+ cast<secret::SecretType>(op.getOperandTypes ()[0 ]).getValueType ()));
322
+ if (res != SliceVerificationResult::Success) {
323
+ return rewriter.notifyMatchFailure (
324
+ op, " expected input type to be a rank reduced type of the output" );
325
+ }
326
+ if (!isa<lwe::LWECiphertextType>(inputs[0 ].getType ())) {
327
+ return rewriter.notifyMatchFailure (
328
+ op, " expected input that was expanded to be of type RLWE ciphertext" );
329
+ }
330
+
331
+ if (!isa<RankedTensorType>(outputTypes[0 ])) {
332
+ return rewriter.notifyMatchFailure (
333
+ op, " expected expanded output to be a ranked tensor" );
334
+ }
335
+ return rewriter
336
+ .replaceOpWithNewOp <tensor::FromElementsOp>(op, outputTypes, inputs)
337
+ .getOperation ();
338
+ }
339
+ };
340
+
341
+ class SecretGenericTensorCollapseConversion
342
+ : public SecretGenericOpConversion<tensor::CollapseShapeOp,
343
+ tensor::CollapseShapeOp> {
344
+ public:
345
+ using SecretGenericOpConversion<
346
+ tensor::CollapseShapeOp,
347
+ tensor::CollapseShapeOp>::SecretGenericOpConversion;
348
+
349
+ FailureOr<Operation*> matchAndRewriteInner (
350
+ secret::GenericOp op, TypeRange outputTypes, ValueRange inputs,
351
+ ArrayRef<NamedAttribute> attributes,
352
+ ContextAwareConversionPatternRewriter& rewriter) const override {
353
+ // We expect this operation to occur when dropping unit dimensions in order
354
+ // to allow rotation ops to operate on 1-D tensors.
355
+ SliceVerificationResult res = isRankReducedType (
356
+ cast<ShapedType>(
357
+ cast<secret::SecretType>(op.getOperandTypes ()[0 ]).getValueType ()),
358
+ cast<ShapedType>(
359
+ cast<secret::SecretType>(op.getResultTypes ()[0 ]).getValueType ()));
360
+ if (res != SliceVerificationResult::Success) {
361
+ return rewriter.notifyMatchFailure (
362
+ op, " expected input type to be a rank reduced type of the output" );
363
+ }
364
+ if (!isa<RankedTensorType>(inputs[0 ].getType ())) {
365
+ return rewriter.notifyMatchFailure (
366
+ op, " expected input that was collapsed to be a ranked tensor" );
367
+ }
368
+ if (!isa<lwe::LWECiphertextType>(outputTypes[0 ])) {
369
+ return rewriter.notifyMatchFailure (
370
+ op, " expected collapsed output to be of type RLWE ciphertext" );
371
+ }
372
+
373
+ Value idx = rewriter.create <arith::ConstantIndexOp>(op.getLoc (), 0 );
374
+ return rewriter
375
+ .replaceOpWithNewOp <tensor::ExtractOp>(op, outputTypes[0 ], inputs[0 ],
376
+ idx)
377
+ .getOperation ();
378
+ }
379
+ };
380
+
268
381
bool hasSecretOperandsOrResults (Operation* op) {
269
382
return llvm::any_of (op->getOperands (),
270
383
[](Value operand) {
@@ -354,7 +467,8 @@ struct SecretToCKKS : public impl::SecretToCKKSBase<SecretToCKKS> {
354
467
target.addDynamicallyLegalOp <func::CallOp>(
355
468
[&](Operation* op) { return typeConverter.isLegal (op); });
356
469
target.addDynamicallyLegalOp <tensor::ExtractOp, tensor::ExtractSliceOp,
357
- tensor::InsertOp>(
470
+ tensor::InsertOp, tensor::ExpandShapeOp,
471
+ tensor::CollapseShapeOp>(
358
472
[&](Operation* op) { return typeConverter.isLegal (op); });
359
473
360
474
target.markUnknownOpDynamicallyLegal (
@@ -394,6 +508,8 @@ struct SecretToCKKS : public impl::SecretToCKKSBase<SecretToCKKS> {
394
508
SecretGenericOpLevelReduceConversion<ckks::LevelReduceOp>,
395
509
SecretGenericTensorExtractConversion,
396
510
SecretGenericTensorInsertConversion,
511
+ SecretGenericTensorCollapseConversion,
512
+ SecretGenericTensorExpandConversion,
397
513
ConvertAnyContextAware<affine::AffineForOp>,
398
514
ConvertAnyContextAware<affine::AffineYieldOp>,
399
515
ConvertAnyContextAware<tensor::ExtractSliceOp>,
0 commit comments