|
16 | 16 | #include "lib/Dialect/TensorExt/IR/TensorExtOps.h"
|
17 | 17 | #include "lib/Transforms/ConvertToCiphertextSemantics/AssignLayout.h"
|
18 | 18 | #include "lib/Transforms/ConvertToCiphertextSemantics/TypeConversion.h"
|
| 19 | +#include "lib/Transforms/DropUnitDims/DropUnitDims.h" |
19 | 20 | #include "lib/Utils/AffineMapUtils.h"
|
20 | 21 | #include "lib/Utils/AttributeUtils.h"
|
21 | 22 | #include "lib/Utils/ContextAwareConversionUtils.h"
|
|
27 | 28 | #include "lib/Utils/Utils.h"
|
28 | 29 | #include "llvm/include/llvm/ADT/ArrayRef.h" // from @llvm-project
|
29 | 30 | #include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
|
| 31 | +#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project |
30 | 32 | #include "llvm/include/llvm/ADT/StringExtras.h" // from @llvm-project
|
31 | 33 | #include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
|
32 | 34 | #include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
|
|
47 | 49 | #include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project
|
48 | 50 | #include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
|
49 | 51 | #include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
|
| 52 | +#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project |
50 | 53 | #include "mlir/include/mlir/IR/OperationSupport.h" // from @llvm-project
|
51 | 54 | #include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
|
52 | 55 | #include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
|
53 | 56 | #include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
|
54 | 57 | #include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project
|
55 | 58 | #include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
|
56 |
| -#include "mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h" // from @llvm-project |
57 | 59 |
|
58 | 60 | #define DEBUG_TYPE "convert-to-ciphertext-semantics"
|
59 | 61 |
|
@@ -1357,7 +1359,6 @@ class ConvertExpandShape
|
1357 | 1359 | if (!sourceLayout) {
|
1358 | 1360 | return op.emitError() << "failed to fetch new layout attribute for input";
|
1359 | 1361 | }
|
1360 |
| - op.dump(); |
1361 | 1362 |
|
1362 | 1363 | if (resultType != srcType) {
|
1363 | 1364 | return rewriter.notifyMatchFailure(
|
@@ -1385,6 +1386,99 @@ class ConvertExpandShape
|
1385 | 1386 | }
|
1386 | 1387 | };
|
1387 | 1388 |
|
| 1389 | +struct DropRotateUnitDims : OpRewritePattern<tensor_ext::RotateOp> { |
| 1390 | + using OpRewritePattern<tensor_ext::RotateOp>::OpRewritePattern; |
| 1391 | + |
| 1392 | + LogicalResult matchAndRewrite(tensor_ext::RotateOp rotateOp, |
| 1393 | + PatternRewriter& rewriter) const override { |
| 1394 | + SmallVector<int64_t> operandUnitDims = |
| 1395 | + getUnitDims(rotateOp.getTensor().getType()); |
| 1396 | + if (operandUnitDims.empty()) { |
| 1397 | + LLVM_DEBUG(llvm::dbgs() << "no unit dims to drop"); |
| 1398 | + return failure(); |
| 1399 | + } |
| 1400 | + |
| 1401 | + SmallVector<Value> collapsedOperands = |
| 1402 | + collapseOperands(rewriter, {rotateOp.getTensor()}, operandUnitDims); |
| 1403 | + |
| 1404 | + tensor_ext::RotateOp collapsedOp = tensor_ext::RotateOp::create( |
| 1405 | + rewriter, rotateOp.getLoc(), collapsedOperands[0], rotateOp.getShift()); |
| 1406 | + rewriter.replaceOp(rotateOp, expandResult(rewriter, collapsedOp.getResult(), |
| 1407 | + rotateOp.getOutput().getType(), |
| 1408 | + operandUnitDims)); |
| 1409 | + return success(); |
| 1410 | + } |
| 1411 | +}; |
| 1412 | + |
| 1413 | +struct DropRotateAndReduceUnitDims |
| 1414 | + : OpRewritePattern<tensor_ext::RotateAndReduceOp> { |
| 1415 | + using OpRewritePattern<tensor_ext::RotateAndReduceOp>::OpRewritePattern; |
| 1416 | + |
| 1417 | + LogicalResult matchAndRewrite(tensor_ext::RotateAndReduceOp rotateOp, |
| 1418 | + PatternRewriter& rewriter) const override { |
| 1419 | + SmallVector<int64_t> operandUnitDims = |
| 1420 | + getUnitDims(rotateOp.getTensor().getType()); |
| 1421 | + if (operandUnitDims.empty()) { |
| 1422 | + LLVM_DEBUG(llvm::dbgs() << "no unit dims to drop"); |
| 1423 | + return failure(); |
| 1424 | + } |
| 1425 | + |
| 1426 | + SmallVector<Value> collapsedOperands = |
| 1427 | + collapseOperands(rewriter, {rotateOp.getTensor()}, operandUnitDims); |
| 1428 | + |
| 1429 | + auto collapsedOp = tensor_ext::RotateAndReduceOp::create( |
| 1430 | + rewriter, rotateOp.getLoc(), collapsedOperands[0], |
| 1431 | + rotateOp.getPlaintexts(), rotateOp.getPeriod(), rotateOp.getSteps()); |
| 1432 | + rewriter.replaceOp(rotateOp, expandResult(rewriter, collapsedOp.getResult(), |
| 1433 | + rotateOp.getOutput().getType(), |
| 1434 | + operandUnitDims)); |
| 1435 | + return success(); |
| 1436 | + } |
| 1437 | +}; |
| 1438 | + |
// Collapses unit (size-1) dimensions on the operands of any op carrying the
// Elementwise trait, re-creates the op generically on the rank-reduced
// tensors, and expands the single result back to its original shape.
// Restricted to single-result ops whose operand and result types all match,
// so one collapsed type can stand in for every operand and the result.
struct DropElementwiseUnitDims : OpTraitRewritePattern<OpTrait::Elementwise> {
  explicit DropElementwiseUnitDims(MLIRContext* context)
      : OpTraitRewritePattern(context) {}

  LogicalResult matchAndRewrite(mlir::Operation* op,
                                PatternRewriter& rewriter) const override {
    // Ensure that all operands and results have the same type.
    SmallVector<Type> operandAndResultTypes =
        llvm::to_vector(op->getOperandTypes());
    operandAndResultTypes.append(op->getResultTypes().begin(),
                                 op->getResultTypes().end());
    // Note: all_equal is vacuously true on an empty range, hence the explicit
    // operand-count check; exactly one result is required so a single
    // collapsed type can be used below.
    if (!llvm::all_equal(operandAndResultTypes) || op->getNumOperands() == 0 ||
        op->getNumResults() != 1) {
      return failure();
    }

    // Only ranked tensors have static unit dims we can collapse.
    auto tensorType = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
    if (!tensorType) {
      return failure();
    }

    SmallVector<int64_t> operandUnitDims = getUnitDims(tensorType);
    if (operandUnitDims.empty()) {
      LLVM_DEBUG(llvm::dbgs() << "no unit dims to drop");
      return failure();
    }

    // Collapse every operand identically (types all match, so the same unit
    // dims apply to each).
    SmallVector<Value> collapsedOperands = collapseOperands(
        rewriter, llvm::to_vector(op->getOperands()), operandUnitDims);

    // Re-create the op by name with the collapsed operands; the collapsed
    // operand type doubles as the single result type. Attributes and
    // successors are carried over verbatim. NOTE(review): regions are not
    // copied — assumes Elementwise ops are region-free; confirm if that
    // invariant ever changes.
    Type resultType = collapsedOperands[0].getType();
    Operation* collapsedOp = rewriter.create(OperationState(
        op->getLoc(), op->getName().getStringRef(), collapsedOperands,
        resultType, op->getAttrs(), op->getSuccessors()));

    // Expand the collapsed result back to the original result shape and
    // replace all uses of the old op.
    rewriter.replaceOp(
        op, expandResult(rewriter, collapsedOp->getResults()[0],
                         cast<RankedTensorType>(op->getResult(0).getType()),
                         operandUnitDims));
    return success();
  }
};
| 1481 | + |
1388 | 1482 | struct ConvertToCiphertextSemantics
|
1389 | 1483 | : impl::ConvertToCiphertextSemanticsBase<ConvertToCiphertextSemantics> {
|
1390 | 1484 | using ConvertToCiphertextSemanticsBase::ConvertToCiphertextSemanticsBase;
|
@@ -1425,6 +1519,10 @@ struct ConvertToCiphertextSemantics
|
1425 | 1519 | // Note ConvertAssignLayout generates tensor.concat
|
1426 | 1520 | RewritePatternSet cleanupPatterns2(context);
|
1427 | 1521 | tensor::populateDecomposeTensorConcatPatterns(cleanupPatterns2);
|
| 1522 | + // Drop unit dimensions for tensor_ext ops that require 1-D tensors (i.e. |
| 1523 | + // rotation ops) and elementwise ops. |
| 1524 | + cleanupPatterns2.add<DropRotateUnitDims, DropRotateAndReduceUnitDims, |
| 1525 | + DropElementwiseUnitDims>(context); |
1428 | 1526 | // Folding here will remove any unrealized conversion cast ops that were
|
1429 | 1527 | // inserted to persist new layouts.
|
1430 | 1528 | if (failed(applyPatternsGreedily(module, std::move(cleanupPatterns2)))) {
|
|
0 commit comments