Skip to content

Commit d168e37

Browse files
asraa and copybara-github
authored and committed
Drop unit dimensions in tensor_ext.rotate and tensor_ext.rotate_and_reduce.
This allows rotation operations to be run on 1-D tensors.

PiperOrigin-RevId: 800635444
1 parent 00832e9 commit d168e37

File tree

6 files changed

+194
-39
lines changed

6 files changed

+194
-39
lines changed

lib/Transforms/ConvertToCiphertextSemantics/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ cc_library(
1616
":pass_inc_gen",
1717
"@heir//lib/Dialect/Secret/IR:SecretPatterns",
1818
"@heir//lib/Dialect/TensorExt/IR:Dialect",
19+
"@heir//lib/Transforms/DropUnitDims",
1920
"@heir//lib/Utils",
2021
"@heir//lib/Utils:AffineMapUtils",
2122
"@heir//lib/Utils:AttributeUtils",

lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.cpp

Lines changed: 100 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "lib/Dialect/TensorExt/IR/TensorExtOps.h"
1717
#include "lib/Transforms/ConvertToCiphertextSemantics/AssignLayout.h"
1818
#include "lib/Transforms/ConvertToCiphertextSemantics/TypeConversion.h"
19+
#include "lib/Transforms/DropUnitDims/DropUnitDims.h"
1920
#include "lib/Utils/AffineMapUtils.h"
2021
#include "lib/Utils/AttributeUtils.h"
2122
#include "lib/Utils/ContextAwareConversionUtils.h"
@@ -27,6 +28,7 @@
2728
#include "lib/Utils/Utils.h"
2829
#include "llvm/include/llvm/ADT/ArrayRef.h" // from @llvm-project
2930
#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
31+
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
3032
#include "llvm/include/llvm/ADT/StringExtras.h" // from @llvm-project
3133
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
3234
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
@@ -47,13 +49,13 @@
4749
#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project
4850
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
4951
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
52+
#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project
5053
#include "mlir/include/mlir/IR/OperationSupport.h" // from @llvm-project
5154
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
5255
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
5356
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
5457
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project
5558
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
56-
#include "mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h" // from @llvm-project
5759

5860
#define DEBUG_TYPE "convert-to-ciphertext-semantics"
5961

@@ -1357,7 +1359,6 @@ class ConvertExpandShape
13571359
if (!sourceLayout) {
13581360
return op.emitError() << "failed to fetch new layout attribute for input";
13591361
}
1360-
op.dump();
13611362

13621363
if (resultType != srcType) {
13631364
return rewriter.notifyMatchFailure(
@@ -1385,6 +1386,99 @@ class ConvertExpandShape
13851386
}
13861387
};
13871388

1389+
// Rewrites a tensor_ext.rotate whose input tensor has unit dimensions so that
// it operates on the collapsed (unit-dim-free) tensor, then expands the
// rotated result back to the original output type. This lets rotation ops be
// lowered as effectively 1-D operations.
struct DropRotateUnitDims : OpRewritePattern<tensor_ext::RotateOp> {
  using OpRewritePattern<tensor_ext::RotateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor_ext::RotateOp rotateOp,
                                PatternRewriter& rewriter) const override {
    SmallVector<int64_t> unitDims = getUnitDims(rotateOp.getTensor().getType());
    if (unitDims.empty()) {
      LLVM_DEBUG(llvm::dbgs() << "no unit dims to drop");
      return failure();
    }

    // Collapse the sole operand, rotate on the collapsed tensor, and expand
    // the result back to the original shape.
    Value collapsed =
        collapseOperands(rewriter, {rotateOp.getTensor()}, unitDims)[0];
    auto newRotate = tensor_ext::RotateOp::create(
        rewriter, rotateOp.getLoc(), collapsed, rotateOp.getShift());
    Value expanded = expandResult(rewriter, newRotate.getResult(),
                                  rotateOp.getOutput().getType(), unitDims);
    rewriter.replaceOp(rotateOp, expanded);
    return success();
  }
};
1412+
1413+
struct DropRotateAndReduceUnitDims
1414+
: OpRewritePattern<tensor_ext::RotateAndReduceOp> {
1415+
using OpRewritePattern<tensor_ext::RotateAndReduceOp>::OpRewritePattern;
1416+
1417+
LogicalResult matchAndRewrite(tensor_ext::RotateAndReduceOp rotateOp,
1418+
PatternRewriter& rewriter) const override {
1419+
SmallVector<int64_t> operandUnitDims =
1420+
getUnitDims(rotateOp.getTensor().getType());
1421+
if (operandUnitDims.empty()) {
1422+
LLVM_DEBUG(llvm::dbgs() << "no unit dims to drop");
1423+
return failure();
1424+
}
1425+
1426+
SmallVector<Value> collapsedOperands =
1427+
collapseOperands(rewriter, {rotateOp.getTensor()}, operandUnitDims);
1428+
1429+
auto collapsedOp = tensor_ext::RotateAndReduceOp::create(
1430+
rewriter, rotateOp.getLoc(), collapsedOperands[0],
1431+
rotateOp.getPlaintexts(), rotateOp.getPeriod(), rotateOp.getSteps());
1432+
rewriter.replaceOp(rotateOp, expandResult(rewriter, collapsedOp.getResult(),
1433+
rotateOp.getOutput().getType(),
1434+
operandUnitDims));
1435+
return success();
1436+
}
1437+
};
1438+
1439+
// Drops unit dimensions from single-result elementwise ops whose operands and
// result all share one ranked tensor type: every operand is collapsed, the op
// is recreated (same name and attributes) on the collapsed type, and the
// result is expanded back to the original shape.
struct DropElementwiseUnitDims : OpTraitRewritePattern<OpTrait::Elementwise> {
  explicit DropElementwiseUnitDims(MLIRContext* context)
      : OpTraitRewritePattern(context) {}

  LogicalResult matchAndRewrite(mlir::Operation* op,
                                PatternRewriter& rewriter) const override {
    // Only handle ops with at least one operand, exactly one result, and a
    // single shared type across all operands and results.
    if (op->getNumOperands() == 0 || op->getNumResults() != 1) {
      return failure();
    }
    SmallVector<Type> allTypes = llvm::to_vector(op->getOperandTypes());
    allTypes.append(op->getResultTypes().begin(), op->getResultTypes().end());
    if (!llvm::all_equal(allTypes)) {
      return failure();
    }

    auto tensorType = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
    if (!tensorType) {
      return failure();
    }

    SmallVector<int64_t> unitDims = getUnitDims(tensorType);
    if (unitDims.empty()) {
      LLVM_DEBUG(llvm::dbgs() << "no unit dims to drop");
      return failure();
    }

    SmallVector<Value> collapsed = collapseOperands(
        rewriter, llvm::to_vector(op->getOperands()), unitDims);

    // Recreate the op generically on the collapsed operands; the collapsed
    // result type equals the collapsed operand type since all types match.
    OperationState state(op->getLoc(), op->getName().getStringRef(), collapsed,
                         collapsed[0].getType(), op->getAttrs(),
                         op->getSuccessors());
    Operation* newOp = rewriter.create(state);

    rewriter.replaceOp(
        op, expandResult(rewriter, newOp->getResults()[0],
                         cast<RankedTensorType>(op->getResult(0).getType()),
                         unitDims));
    return success();
  }
};
1481+
13881482
struct ConvertToCiphertextSemantics
13891483
: impl::ConvertToCiphertextSemanticsBase<ConvertToCiphertextSemantics> {
13901484
using ConvertToCiphertextSemanticsBase::ConvertToCiphertextSemanticsBase;
@@ -1425,6 +1519,10 @@ struct ConvertToCiphertextSemantics
14251519
// Note ConvertAssignLayout generates tensor.concat
14261520
RewritePatternSet cleanupPatterns2(context);
14271521
tensor::populateDecomposeTensorConcatPatterns(cleanupPatterns2);
1522+
// Drop unit dimensions for tensor_ext ops that require 1-D tensors (i.e.
1523+
// rotation ops) and elementwise ops.
1524+
cleanupPatterns2.add<DropRotateUnitDims, DropRotateAndReduceUnitDims,
1525+
DropElementwiseUnitDims>(context);
14281526
// Folding here will remove any unrealized conversion cast ops that were
14291527
// inserted to persist new layouts.
14301528
if (failed(applyPatternsGreedily(module, std::move(cleanupPatterns2)))) {

lib/Transforms/DropUnitDims/DropUnitDims.cpp

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -37,21 +37,30 @@ namespace heir {
3737

3838
namespace {
3939

40-
/// Collapse the given `value` so that the type matches the type of
41-
/// `origOutput`.
42-
static Value collapseValue(RewriterBase& rewriter, Location loc, Value operand,
43-
ArrayRef<int64_t> targetShape,
44-
ArrayRef<ReassociationIndices> reassociation) {
40+
// Collapses `operand` to `targetShape` via tensor.collapse_shape with the
// given reassociation, preserving the element type.
Value collapseValue(RewriterBase& rewriter, Location loc, Value operand,
                    ArrayRef<int64_t> targetShape,
                    ArrayRef<ReassociationIndices> reassociation) {
  auto elementType =
      cast<RankedTensorType>(operand.getType()).getElementType();
  auto collapsedType = RankedTensorType::get(targetShape, elementType);
  return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, operand,
                                         reassociation);
}
5149

52-
/// Returns a collapsed `val` where the collapsing occurs at dims in positions.
53-
static Value collapseDimsAt(PatternRewriter& rewriter, Value val,
54-
ArrayRef<int64_t> positions) {
50+
} // namespace
51+
52+
// Returns the indices of all dimensions of `type` with static size 1.
SmallVector<int64_t> getUnitDims(ShapedType type) {
  SmallVector<int64_t> result;
  for (int64_t dim = 0, rank = type.getRank(); dim < rank; ++dim) {
    if (type.getDimSize(dim) == 1) result.push_back(dim);
  }
  return result;
}
61+
62+
Value collapseDimsAt(PatternRewriter& rewriter, Value val,
63+
ArrayRef<int64_t> positions) {
5564
auto valType = cast<ShapedType>(val.getType());
5665
SmallVector<int64_t> collapsedShape(valType.getShape());
5766
for (int64_t pos : llvm::reverse(positions)) {
@@ -62,7 +71,22 @@ static Value collapseDimsAt(PatternRewriter& rewriter, Value val,
6271
getReassociationForReshapeAtDim(valType.getRank(), positions));
6372
}
6473

65-
} // namespace
74+
/// Collapse all collapsible operands.
75+
SmallVector<Value> collapseOperands(PatternRewriter& rewriter,
76+
ArrayRef<Value> operands,
77+
ArrayRef<int64_t> collapseDims) {
78+
return llvm::map_to_vector(operands, [&](auto operand) {
79+
return collapseDimsAt(rewriter, operand, collapseDims);
80+
});
81+
}
82+
83+
/// Expand result tensor.
84+
Value expandResult(PatternRewriter& rewriter, Value result,
85+
RankedTensorType expandedType, SmallVector<int64_t> dims) {
86+
return tensor::ExpandShapeOp::create(
87+
rewriter, result.getLoc(), expandedType, result,
88+
getReassociationForReshapeAtDim(expandedType.getRank(), dims));
89+
}
6690

6791
// Drop unit dims on linalg.map operations that perform a single elementwise
6892
// operation. This will only drop batch dims (leading unit dimensions). This
@@ -72,24 +96,6 @@ static Value collapseDimsAt(PatternRewriter& rewriter, Value val,
7296
struct ReduceLinalgMap : OpRewritePattern<linalg::MapOp> {
7397
using OpRewritePattern<linalg::MapOp>::OpRewritePattern;
7498

75-
/// Collapse all collapsible operands.
76-
SmallVector<Value> collapseOperands(PatternRewriter& rewriter,
77-
ArrayRef<Value> operands,
78-
ArrayRef<int64_t> collapseDims) const {
79-
return llvm::map_to_vector(operands, [&](auto operand) {
80-
return collapseDimsAt(rewriter, operand, collapseDims);
81-
});
82-
}
83-
84-
/// Expand result tensor.
85-
Value expandResult(PatternRewriter& rewriter, Value result,
86-
RankedTensorType expandedType,
87-
SmallVector<int64_t> dims) const {
88-
return tensor::ExpandShapeOp::create(
89-
rewriter, result.getLoc(), expandedType, result,
90-
getReassociationForReshapeAtDim(expandedType.getRank(), dims));
91-
}
92-
9399
LogicalResult matchAndRewrite(linalg::MapOp mapOp,
94100
PatternRewriter& rewriter) const override {
95101
if (mapOp.hasUserDefinedMaps()) {
@@ -114,14 +120,8 @@ struct ReduceLinalgMap : OpRewritePattern<linalg::MapOp> {
114120

115121
// Check for unit dims in the output shape. A map op requires all inputs and
116122
// outputs have the same shape.
117-
auto outputShape = mapOp.getInit().getType().getShape();
118-
SmallVector<int64_t> operandUnitDims;
119-
for (int64_t i = 0; i < outputShape.size(); ++i) {
120-
if (outputShape[i] == 1) {
121-
operandUnitDims.push_back(i);
122-
}
123-
}
124-
123+
SmallVector<int64_t> operandUnitDims =
124+
getUnitDims(mapOp.getInit().getType());
125125
if (operandUnitDims.empty()) {
126126
LLVM_DEBUG(llvm::dbgs() << "no unit dims to drop");
127127
return failure();

lib/Transforms/DropUnitDims/DropUnitDims.h

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
#ifndef LIB_TRANSFORMS_DROPUNITDIMS_DROPUNITDIMS_H_
22
#define LIB_TRANSFORMS_DROPUNITDIMS_DROPUNITDIMS_H_
33

4-
#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project
4+
#include "mlir/include/mlir/Dialect/Arith/Utils/Utils.h" // from @llvm-project
5+
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
6+
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
7+
#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project
8+
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
59

610
namespace mlir {
711
namespace heir {
@@ -12,6 +16,22 @@ namespace heir {
1216
#define GEN_PASS_REGISTRATION
1317
#include "lib/Transforms/DropUnitDims/DropUnitDims.h.inc"
1418

19+
// Returns a list of unit dims of a type
20+
SmallVector<int64_t> getUnitDims(ShapedType type);
21+
22+
/// Returns a collapsed `val` where the collapsing occurs at dims in positions.
23+
Value collapseDimsAt(PatternRewriter& rewriter, Value val,
24+
ArrayRef<int64_t> positions);
25+
26+
/// Collapse all collapsible operands.
27+
SmallVector<Value> collapseOperands(PatternRewriter& rewriter,
28+
ArrayRef<Value> operands,
29+
ArrayRef<int64_t> collapseDims);
30+
31+
/// Expand result tensor.
32+
Value expandResult(PatternRewriter& rewriter, Value result,
33+
RankedTensorType expandedType, SmallVector<int64_t> dims);
34+
1535
} // namespace heir
1636
} // namespace mlir
1737

tests/Transforms/convert_to_ciphertext_semantics/collapse_shape.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ module {
3434
%1 = tensor_ext.assign_layout %cst {layout = #new_layout2, tensor_ext.layout = #new_layout2} : tensor<512xf32>
3535
// CHECK: secret.generic
3636
// CHECK-NEXT: ^body(%[[input0:.*]]: tensor<1x1024xf32>)
37-
// CHECK: tensor_ext.rotate_and_reduce %[[input0]]
37+
// CHECK: %[[collapsed:.*]] = tensor.collapse_shape %[[input0]]
38+
// CHECK: tensor_ext.rotate_and_reduce %[[collapsed]]
3839
%7 = secret.generic(%arg4: !secret.secret<tensor<1x784xf32>> {tensor_ext.layout = #new_layout5}) {
3940
^body(%input0: tensor<1x784xf32>):
4041
%collapsed = tensor.collapse_shape %input0 [[0, 1]] {tensor_ext.layout = #new_layout6} : tensor<1x784xf32> into tensor<784xf32>
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// RUN: heir-opt %s --convert-to-ciphertext-semantics --split-input-file | FileCheck %s
2+
3+
#kernel = #secret.kernel<name = "MatvecDiagonal", force = false>
4+
#new_layout1 = #tensor_ext.new_layout<"{ [i0, i1] -> [ct, slot] : (i0 - i1 + ct) mod 512 = 0 and (-i1 + ct + slot) mod 1024 = 0 and 0 <= i0 <= 511 and 0 <= i1 <= 783 and 0 <= ct <= 511 and 0 <= slot <= 1023 }">
5+
#new_layout2 = #tensor_ext.new_layout<"{ [i0] -> [ct, slot] : ct = 0 and (-i0 + slot) mod 512 = 0 and 0 <= i0 <= 511 and 0 <= slot <= 1023 }">
6+
#new_layout5 = #tensor_ext.new_layout<"{ [i0, i1] -> [ct, slot] : i0 = 0 and ct = 0 and (-i1 + slot) mod 1024 = 0 and 0 <= i1 <= 783 and 0 <= slot <= 1023 }">
7+
#new_layout6 = #tensor_ext.new_layout<"{ [i0] -> [ct, slot] : ct = 0 and (-i0 + slot) mod 1024 = 0 and 0 <= i0 <= 783 and 0 <= slot <= 1023 }">
8+
module{
9+
// CHECK: func.func @main
10+
func.func @main(%arg0: tensor<512x784xf32>, %arg1: tensor<512xf32>, %arg4: !secret.secret<tensor<1x784xf32>> {tensor_ext.layout = #new_layout5}) -> (!secret.secret<tensor<512xf32>> {jax.result_info = "result[0]", tensor_ext.layout = #new_layout2}) {
11+
%cst = arith.constant dense<0.000000e+00> : tensor<512xf32>
12+
%0 = tensor_ext.assign_layout %arg0 {layout = #new_layout1, tensor_ext.layout = #new_layout1} : tensor<512x784xf32>
13+
%1 = tensor_ext.assign_layout %cst {layout = #new_layout2, tensor_ext.layout = #new_layout2} : tensor<512xf32>
14+
%2 = tensor_ext.assign_layout %arg1 {layout = #new_layout2, tensor_ext.layout = #new_layout2} : tensor<512xf32>
15+
// CHECK: secret.generic(%[[arg2:.*]]: !secret.secret<tensor<1x1024xf32>>)
16+
// CHECK: ^body(%[[input0:.*]]: tensor<1x1024xf32>)
17+
// CHECK: %[[collapsed:.*]] = tensor.collapse_shape %[[input0]]
18+
// CHECK-SAME: tensor<1x1024xf32> into tensor<1024xf32>
19+
// CHECK: %[[v4:.*]] = tensor_ext.rotate_and_reduce %[[collapsed]]
20+
// CHECK: %[[collapsed_2:.*]] = tensor.collapse_shape
21+
// CHECK: %[[v5:.*]] = arith.addf %[[v4]], %[[collapsed_2]] : tensor<1024xf32>
22+
// CHECK: %[[v6:.*]] = tensor_ext.rotate %[[v5]], %[[c512:.*]] : tensor<1024xf32>
23+
// CHECK: %[[expanded:.*]] = tensor.expand_shape
24+
// CHECK-SAME: tensor<1024xf32> into tensor<1x1024xf32>
25+
// CHECK: secret.yield %[[expanded]]
26+
%7 = secret.generic(%arg4: !secret.secret<tensor<1x784xf32>> {tensor_ext.layout = #new_layout5}) {
27+
^body(%input0: tensor<1x784xf32>):
28+
%collapsed = tensor.collapse_shape %input0 [[0, 1]] {tensor_ext.layout = #new_layout6} : tensor<1x784xf32> into tensor<784xf32>
29+
%8 = linalg.matvec {secret.kernel = #kernel, tensor_ext.layout = #new_layout2} ins(%0, %collapsed : tensor<512x784xf32>, tensor<784xf32>) outs(%1 : tensor<512xf32>) -> tensor<512xf32>
30+
%9 = arith.addf %2, %8 {tensor_ext.layout = #new_layout2} : tensor<512xf32>
31+
secret.yield %9 : tensor<512xf32>
32+
} -> (!secret.secret<tensor<512xf32>> {tensor_ext.layout = #new_layout2})
33+
return %7 : !secret.secret<tensor<512xf32>>
34+
}
35+
}

0 commit comments

Comments (0)