Skip to content

Commit ecb9cf0

Browse files
asraa authored and copybara-github committed
Support tensor reshape op in secret-to-ckks
PiperOrigin-RevId: 800485015
1 parent cd0dd55 commit ecb9cf0

File tree

12 files changed

+397
-75
lines changed

12 files changed

+397
-75
lines changed

lib/Dialect/Secret/Conversions/SecretToCKKS/SecretToCKKS.cpp

Lines changed: 128 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project
3939
#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
4040
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
41+
#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project
4142
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
4243
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project
4344
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
@@ -153,7 +154,7 @@ class SecretToCKKSTypeConverter
153154

154155
auto ciphertext = lwe::LWECiphertextType::get(
155156
ctx,
156-
lwe::ApplicationDataAttr::get(ctx, type.getValueType(),
157+
lwe::ApplicationDataAttr::get(ctx, valueTy,
157158
lwe::NoOverflowAttr::get(ctx)),
158159
lwe::PlaintextSpaceAttr::get(
159160
ctx, plaintextRing,
@@ -164,25 +165,60 @@ class SecretToCKKSTypeConverter
164165
lwe::KeyAttr::get(ctx, 0),
165166
lwe::ModulusChainAttr::get(ctx, moduliChain, level));
166167

167-
// Return a single ciphertext if inputs are packed into a single
168-
// ciphertext SIMD slot or the secret value type is a scalar.
169-
if (this->packTensorInSlots_ || !isa<TensorType>(valueTy)) {
170-
return ciphertext;
168+
// Return a single ciphertext if the input is a scalar.
169+
if (!isa<TensorType>(valueTy)) return ciphertext;
170+
171+
// The input is a tensor type.
172+
assert(dyn_cast<RankedTensorType>(valueTy) &&
173+
"expected ranked tensor type");
174+
auto tensorTy = cast<RankedTensorType>(valueTy);
175+
// If the input is packed into a ciphertext SIMD slots (i.e. it is a tensor
176+
// of shape NxciphertextSize) then return Nxciphertext.
177+
if (this->packTensorInSlots_) {
178+
Type underlyingTy;
179+
if (tensorTy.getRank() == 1) {
180+
underlyingTy = valueTy;
181+
auto ciphertext = lwe::LWECiphertextType::get(
182+
ctx,
183+
lwe::ApplicationDataAttr::get(ctx, underlyingTy,
184+
lwe::NoOverflowAttr::get(ctx)),
185+
lwe::PlaintextSpaceAttr::get(
186+
ctx, plaintextRing,
187+
lwe::InverseCanonicalEncodingAttr::get(ctx, scale)),
188+
lwe::CiphertextSpaceAttr::get(
189+
ctx, getRlweRNSRingWithLevel(ring_, level),
190+
lwe::LweEncryptionType::mix, dimension),
191+
lwe::KeyAttr::get(ctx, 0),
192+
lwe::ModulusChainAttr::get(ctx, moduliChain, level));
193+
return ciphertext;
194+
}
195+
assert(tensorTy.getRank() == 2 && "expected rank 1 or 2 tensor");
196+
underlyingTy = RankedTensorType::get(tensorTy.getShape().drop_front(),
197+
tensorTy.getElementType());
198+
auto ciphertext = lwe::LWECiphertextType::get(
199+
ctx,
200+
lwe::ApplicationDataAttr::get(ctx, underlyingTy,
201+
lwe::NoOverflowAttr::get(ctx)),
202+
lwe::PlaintextSpaceAttr::get(
203+
ctx, plaintextRing,
204+
lwe::InverseCanonicalEncodingAttr::get(ctx, scale)),
205+
lwe::CiphertextSpaceAttr::get(ctx,
206+
getRlweRNSRingWithLevel(ring_, level),
207+
lwe::LweEncryptionType::mix, dimension),
208+
lwe::KeyAttr::get(ctx, 0),
209+
lwe::ModulusChainAttr::get(ctx, moduliChain, level));
210+
return RankedTensorType::get(tensorTy.getShape().drop_back(), ciphertext);
171211
}
172212
// If the input IR does not contain aligned ciphertexts, we will not
173213
// pack tensors into ciphertext SIMD slots, so tensors are converted
174214
// into tensors of RLWE ciphertexts.
175-
assert(dyn_cast<RankedTensorType>(valueTy) &&
176-
"expected ranked tensor type");
177-
auto scalarType = cast<RankedTensorType>(valueTy).getElementType();
178215
ciphertext = lwe::LWECiphertextType::get(
179216
ctx,
180-
lwe::ApplicationDataAttr::get(ctx, scalarType,
217+
lwe::ApplicationDataAttr::get(ctx, getElementTypeOrSelf(valueTy),
181218
lwe::NoOverflowAttr::get(ctx)),
182219
ciphertext.getPlaintextSpace(), ciphertext.getCiphertextSpace(),
183220
ciphertext.getKey(), ciphertext.getModulusChain());
184-
return RankedTensorType::get(cast<RankedTensorType>(valueTy).getShape(),
185-
ciphertext);
221+
return RankedTensorType::get(tensorTy.getShape(), ciphertext);
186222
}
187223

188224
private:
@@ -265,6 +301,83 @@ class SecretGenericTensorInsertConversion
265301
}
266302
};
267303

304+
class SecretGenericTensorExpandConversion
305+
: public SecretGenericOpConversion<tensor::ExpandShapeOp,
306+
tensor::ExpandShapeOp> {
307+
public:
308+
using SecretGenericOpConversion<
309+
tensor::ExpandShapeOp, tensor::ExpandShapeOp>::SecretGenericOpConversion;
310+
311+
FailureOr<Operation*> matchAndRewriteInner(
312+
secret::GenericOp op, TypeRange outputTypes, ValueRange inputs,
313+
ArrayRef<NamedAttribute> attributes,
314+
ContextAwareConversionPatternRewriter& rewriter) const override {
315+
// We expect this operation to occur when dropping unit dimensions in order
316+
// to allow rotation ops to operate on 1-D tensors.
317+
SliceVerificationResult res = isRankReducedType(
318+
cast<ShapedType>(
319+
cast<secret::SecretType>(op.getResultTypes()[0]).getValueType()),
320+
cast<ShapedType>(
321+
cast<secret::SecretType>(op.getOperandTypes()[0]).getValueType()));
322+
if (res != SliceVerificationResult::Success) {
323+
return rewriter.notifyMatchFailure(
324+
op, "expected input type to be a rank reduced type of the output");
325+
}
326+
if (!isa<lwe::LWECiphertextType>(inputs[0].getType())) {
327+
return rewriter.notifyMatchFailure(
328+
op, "expected input that was expanded to be of type RLWE ciphertext");
329+
}
330+
331+
if (!isa<RankedTensorType>(outputTypes[0])) {
332+
return rewriter.notifyMatchFailure(
333+
op, "expected expanded output to be a ranked tensor");
334+
}
335+
return rewriter
336+
.replaceOpWithNewOp<tensor::FromElementsOp>(op, outputTypes, inputs)
337+
.getOperation();
338+
}
339+
};
340+
341+
class SecretGenericTensorCollapseConversion
342+
: public SecretGenericOpConversion<tensor::CollapseShapeOp,
343+
tensor::CollapseShapeOp> {
344+
public:
345+
using SecretGenericOpConversion<
346+
tensor::CollapseShapeOp,
347+
tensor::CollapseShapeOp>::SecretGenericOpConversion;
348+
349+
FailureOr<Operation*> matchAndRewriteInner(
350+
secret::GenericOp op, TypeRange outputTypes, ValueRange inputs,
351+
ArrayRef<NamedAttribute> attributes,
352+
ContextAwareConversionPatternRewriter& rewriter) const override {
353+
// We expect this operation to occur when dropping unit dimensions in order
354+
// to allow rotation ops to operate on 1-D tensors.
355+
SliceVerificationResult res = isRankReducedType(
356+
cast<ShapedType>(
357+
cast<secret::SecretType>(op.getOperandTypes()[0]).getValueType()),
358+
cast<ShapedType>(
359+
cast<secret::SecretType>(op.getResultTypes()[0]).getValueType()));
360+
if (res != SliceVerificationResult::Success) {
361+
return rewriter.notifyMatchFailure(
362+
op, "expected input type to be a rank reduced type of the output");
363+
}
364+
if (!isa<RankedTensorType>(inputs[0].getType())) {
365+
return rewriter.notifyMatchFailure(
366+
op, "expected input that was collapsed to be a ranked tensor");
367+
}
368+
if (!isa<lwe::LWECiphertextType>(outputTypes[0])) {
369+
return rewriter.notifyMatchFailure(
370+
op, "expected collapsed output to be of type RLWE ciphertext");
371+
}
372+
373+
Value idx = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
374+
return rewriter
375+
.replaceOpWithNewOp<tensor::ExtractOp>(op, outputTypes[0], inputs[0],
376+
idx)
377+
.getOperation();
378+
}
379+
};
380+
268381
bool hasSecretOperandsOrResults(Operation* op) {
269382
return llvm::any_of(op->getOperands(),
270383
[](Value operand) {
@@ -354,7 +467,8 @@ struct SecretToCKKS : public impl::SecretToCKKSBase<SecretToCKKS> {
354467
target.addDynamicallyLegalOp<func::CallOp>(
355468
[&](Operation* op) { return typeConverter.isLegal(op); });
356469
target.addDynamicallyLegalOp<tensor::ExtractOp, tensor::ExtractSliceOp,
357-
tensor::InsertOp>(
470+
tensor::InsertOp, tensor::ExpandShapeOp,
471+
tensor::CollapseShapeOp>(
358472
[&](Operation* op) { return typeConverter.isLegal(op); });
359473

360474
target.markUnknownOpDynamicallyLegal(
@@ -394,6 +508,8 @@ struct SecretToCKKS : public impl::SecretToCKKSBase<SecretToCKKS> {
394508
SecretGenericOpLevelReduceConversion<ckks::LevelReduceOp>,
395509
SecretGenericTensorExtractConversion,
396510
SecretGenericTensorInsertConversion,
511+
SecretGenericTensorCollapseConversion,
512+
SecretGenericTensorExpandConversion,
397513
ConvertAnyContextAware<affine::AffineForOp>,
398514
ConvertAnyContextAware<affine::AffineYieldOp>,
399515
ConvertAnyContextAware<tensor::ExtractSliceOp>,

lib/Pipelines/ArithmeticPipelineRegistration.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "lib/Dialect/TensorExt/Conversions/TensorExtToTensor/TensorExtToTensor.h"
2323
#include "lib/Dialect/TensorExt/Transforms/CollapseInsertionChains.h"
2424
#include "lib/Dialect/TensorExt/Transforms/FoldConvertLayoutIntoAssignLayout.h"
25+
#include "lib/Dialect/TensorExt/Transforms/ImplementRotateAndReduce.h"
2526
#include "lib/Dialect/TensorExt/Transforms/InsertRotate.h"
2627
#include "lib/Dialect/TensorExt/Transforms/RotateAndReduce.h"
2728
#include "lib/Pipelines/PipelineRegistration.h"
@@ -152,6 +153,7 @@ void mlirToSecretArithmeticPipelineBuilder(
152153
convertToCiphertextSemanticsOptions.ciphertextSize = options.ciphertextDegree;
153154
pm.addPass(
154155
createConvertToCiphertextSemantics(convertToCiphertextSemanticsOptions));
156+
pm.addPass(tensor_ext::createImplementRotateAndReduce());
155157

156158
mathToPolynomialApproximationBuilder(pm);
157159

lib/Pipelines/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ cc_library(
119119
"@heir//lib/Dialect/TensorExt/Conversions/TensorExtToTensor",
120120
"@heir//lib/Dialect/TensorExt/Transforms:CollapseInsertionChains",
121121
"@heir//lib/Dialect/TensorExt/Transforms:FoldConvertLayoutIntoAssignLayout",
122+
"@heir//lib/Dialect/TensorExt/Transforms:ImplementRotateAndReduce",
122123
"@heir//lib/Dialect/TensorExt/Transforms:InsertRotate",
123124
"@heir//lib/Dialect/TensorExt/Transforms:RotateAndReduce",
124125
"@heir//lib/Transforms/AddClientInterface",

lib/Transforms/ConvertToCiphertextSemantics/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ cc_library(
1616
":pass_inc_gen",
1717
"@heir//lib/Dialect/Secret/IR:SecretPatterns",
1818
"@heir//lib/Dialect/TensorExt/IR:Dialect",
19+
"@heir//lib/Transforms/DropUnitDims",
1920
"@heir//lib/Utils",
2021
"@heir//lib/Utils:AffineMapUtils",
2122
"@heir//lib/Utils:AttributeUtils",

lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.cpp

Lines changed: 100 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "lib/Dialect/TensorExt/IR/TensorExtOps.h"
1717
#include "lib/Transforms/ConvertToCiphertextSemantics/AssignLayout.h"
1818
#include "lib/Transforms/ConvertToCiphertextSemantics/TypeConversion.h"
19+
#include "lib/Transforms/DropUnitDims/DropUnitDims.h"
1920
#include "lib/Utils/AffineMapUtils.h"
2021
#include "lib/Utils/AttributeUtils.h"
2122
#include "lib/Utils/ContextAwareConversionUtils.h"
@@ -27,6 +28,7 @@
2728
#include "lib/Utils/Utils.h"
2829
#include "llvm/include/llvm/ADT/ArrayRef.h" // from @llvm-project
2930
#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
31+
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
3032
#include "llvm/include/llvm/ADT/StringExtras.h" // from @llvm-project
3133
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
3234
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
@@ -47,13 +49,13 @@
4749
#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project
4850
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
4951
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
52+
#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project
5053
#include "mlir/include/mlir/IR/OperationSupport.h" // from @llvm-project
5154
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
5255
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
5356
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
5457
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project
5558
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
56-
#include "mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h" // from @llvm-project
5759

5860
#define DEBUG_TYPE "convert-to-ciphertext-semantics"
5961

@@ -1357,7 +1359,6 @@ class ConvertExpandShape
13571359
if (!sourceLayout) {
13581360
return op.emitError() << "failed to fetch new layout attribute for input";
13591361
}
1360-
op.dump();
13611362

13621363
if (resultType != srcType) {
13631364
return rewriter.notifyMatchFailure(
@@ -1385,6 +1386,99 @@ class ConvertExpandShape
13851386
}
13861387
};
13871388

1389+
// Rewrites tensor_ext.rotate on a tensor with unit dimensions into a rotate
// on the collapsed (unit-dim-free) tensor, re-expanding the result to the
// original shape afterwards.
struct DropRotateUnitDims : OpRewritePattern<tensor_ext::RotateOp> {
  using OpRewritePattern<tensor_ext::RotateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor_ext::RotateOp rotateOp,
                                PatternRewriter& rewriter) const override {
    // Nothing to do unless the rotated tensor carries unit dimensions.
    SmallVector<int64_t> unitDims = getUnitDims(rotateOp.getTensor().getType());
    if (unitDims.empty()) {
      LLVM_DEBUG(llvm::dbgs() << "no unit dims to drop");
      return failure();
    }

    // Collapse away the unit dimensions, rotate the collapsed tensor, then
    // expand the rotated result back to the original output type.
    SmallVector<Value> collapsed =
        collapseOperands(rewriter, {rotateOp.getTensor()}, unitDims);
    auto rotated =
        tensor_ext::RotateOp::create(rewriter, rotateOp.getLoc(), collapsed[0],
                                     rotateOp.getShift());
    Value replacement =
        expandResult(rewriter, rotated.getResult(),
                     rotateOp.getOutput().getType(), unitDims);
    rewriter.replaceOp(rotateOp, replacement);
    return success();
  }
};
1412+
1413+
struct DropRotateAndReduceUnitDims
1414+
: OpRewritePattern<tensor_ext::RotateAndReduceOp> {
1415+
using OpRewritePattern<tensor_ext::RotateAndReduceOp>::OpRewritePattern;
1416+
1417+
LogicalResult matchAndRewrite(tensor_ext::RotateAndReduceOp rotateOp,
1418+
PatternRewriter& rewriter) const override {
1419+
SmallVector<int64_t> operandUnitDims =
1420+
getUnitDims(rotateOp.getTensor().getType());
1421+
if (operandUnitDims.empty()) {
1422+
LLVM_DEBUG(llvm::dbgs() << "no unit dims to drop");
1423+
return failure();
1424+
}
1425+
1426+
SmallVector<Value> collapsedOperands =
1427+
collapseOperands(rewriter, {rotateOp.getTensor()}, operandUnitDims);
1428+
1429+
auto collapsedOp = tensor_ext::RotateAndReduceOp::create(
1430+
rewriter, rotateOp.getLoc(), collapsedOperands[0],
1431+
rotateOp.getPlaintexts(), rotateOp.getPeriod(), rotateOp.getSteps());
1432+
rewriter.replaceOp(rotateOp, expandResult(rewriter, collapsedOp.getResult(),
1433+
rotateOp.getOutput().getType(),
1434+
operandUnitDims));
1435+
return success();
1436+
}
1437+
};
1438+
1439+
// Rewrites any elementwise op whose tensor operands/result carry unit
// dimensions to operate on the collapsed tensors instead, re-expanding the
// single result to the original shape afterwards.
struct DropElementwiseUnitDims : OpTraitRewritePattern<OpTrait::Elementwise> {
  explicit DropElementwiseUnitDims(MLIRContext* context)
      : OpTraitRewritePattern(context) {}

  LogicalResult matchAndRewrite(mlir::Operation* op,
                                PatternRewriter& rewriter) const override {
    // Only handle the simple shape: at least one operand, exactly one
    // result, and a single shared type across all operands and results.
    SmallVector<Type> allTypes(op->getOperandTypes());
    llvm::append_range(allTypes, op->getResultTypes());
    bool uniformSingleResult = op->getNumOperands() != 0 &&
                               op->getNumResults() == 1 &&
                               llvm::all_equal(allTypes);
    if (!uniformSingleResult) {
      return failure();
    }

    auto tensorType = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
    if (!tensorType) {
      return failure();
    }

    // Nothing to do unless the shared tensor type carries unit dimensions.
    SmallVector<int64_t> unitDims = getUnitDims(tensorType);
    if (unitDims.empty()) {
      LLVM_DEBUG(llvm::dbgs() << "no unit dims to drop");
      return failure();
    }

    // Collapse every operand, then recreate the same op (same name, attrs,
    // successors) over the collapsed operands with the collapsed result type.
    SmallVector<Value> collapsed = collapseOperands(
        rewriter, llvm::to_vector(op->getOperands()), unitDims);
    OperationState state(op->getLoc(), op->getName().getStringRef(), collapsed,
                         collapsed[0].getType(), op->getAttrs(),
                         op->getSuccessors());
    Operation* collapsedOp = rewriter.create(state);

    // Expand the collapsed result back to the original result type.
    Value replacement =
        expandResult(rewriter, collapsedOp->getResults()[0],
                     cast<RankedTensorType>(op->getResult(0).getType()),
                     unitDims);
    rewriter.replaceOp(op, replacement);
    return success();
  }
};
1481+
13881482
struct ConvertToCiphertextSemantics
13891483
: impl::ConvertToCiphertextSemanticsBase<ConvertToCiphertextSemantics> {
13901484
using ConvertToCiphertextSemanticsBase::ConvertToCiphertextSemanticsBase;
@@ -1425,6 +1519,10 @@ struct ConvertToCiphertextSemantics
14251519
// Note ConvertAssignLayout generates tensor.concat
14261520
RewritePatternSet cleanupPatterns2(context);
14271521
tensor::populateDecomposeTensorConcatPatterns(cleanupPatterns2);
1522+
// Drop unit dimensions for tensor_ext ops that require 1-D tensors (i.e.
1523+
// rotation ops) and elementwise ops.
1524+
cleanupPatterns2.add<DropRotateUnitDims, DropRotateAndReduceUnitDims,
1525+
DropElementwiseUnitDims>(context);
14281526
// Folding here will remove any unrealized conversion cast ops that were
14291527
// inserted to persist new layouts.
14301528
if (failed(applyPatternsGreedily(module, std::move(cleanupPatterns2)))) {

0 commit comments

Comments
 (0)