Skip to content

Commit 47b98d5

Browse files
asraa authored and copybara-github committed
fix: add reassociation indices for expand_shape op
I noticed a bug when compiling the PyTorch MLP: the reassociation map was missing from the expand_shape op that is used to canonicalize linalg.broadcast — the rewrite did not include reassociation indices. I guess it was invoking default op builders that don't require the map, and maybe other canonicalization patterns were taking precedence? Anyway, I added a (previously failing) test.

PiperOrigin-RevId: 798365675
1 parent 2aa29ec commit 47b98d5

File tree

7 files changed

+64
-35
lines changed

7 files changed

+64
-35
lines changed

lib/Transforms/DropUnitDims/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ cc_library(
1212
hdrs = ["DropUnitDims.h"],
1313
deps = [
1414
":pass_inc_gen",
15+
"@heir//lib/Utils:TensorUtils",
1516
"@llvm-project//llvm:Support",
1617
"@llvm-project//mlir:ArithUtils",
1718
"@llvm-project//mlir:IR",

lib/Transforms/DropUnitDims/DropUnitDims.cpp

Lines changed: 1 addition & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <cstdint>
55
#include <utility>
66

7+
#include "lib/Utils/TensorUtils.h"
78
#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
89
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
910
#include "llvm/include/llvm/ADT/SmallVectorExtras.h" // from @llvm-project
@@ -36,40 +37,6 @@ namespace heir {
3637

3738
namespace {
3839

39-
/// The following functions are copied from
40-
/// llvm-project/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp, where they
41-
/// are in an anonymous namespace.
42-
43-
/// Returns reassociation indices for collapsing/expanding a
44-
/// tensor of rank `rank` at positions in `positions`.
45-
static SmallVector<ReassociationIndices> getReassociationForReshapeAtDim(
46-
int64_t rank, ArrayRef<int64_t> positions) {
47-
SmallVector<ReassociationIndices> reassociation;
48-
reassociation.reserve(rank - positions.size());
49-
50-
llvm::DenseMap<int64_t, bool> positionsMap;
51-
for (int64_t pos : positions) {
52-
positionsMap[pos] = true;
53-
}
54-
auto isUnitDim = [&](int64_t dim) { return positionsMap.contains(dim); };
55-
56-
ReassociationIndices reassociationGroup;
57-
unsigned dim = 0;
58-
while (dim < rank && isUnitDim(dim)) reassociationGroup.push_back(dim++);
59-
while (dim < rank) {
60-
assert(!isUnitDim(dim) && "expected non unit-extent");
61-
reassociationGroup.push_back(dim);
62-
++dim;
63-
// Fold all following dimensions that are unit-extent.
64-
while (dim < rank && isUnitDim(dim)) {
65-
reassociationGroup.push_back(dim++);
66-
}
67-
reassociation.push_back(reassociationGroup);
68-
reassociationGroup.clear();
69-
}
70-
return reassociation;
71-
}
72-
7340
/// Collapse the given `value` so that the type matches the type of
7441
/// `origOutput`.
7542
static Value collapseValue(RewriterBase& rewriter, Location loc, Value operand,

lib/Transforms/LinalgCanonicalizations/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ cc_library(
1414
],
1515
deps = [
1616
":pass_inc_gen",
17+
"@heir//lib/Utils:TensorUtils",
1718
"@llvm-project//llvm:Support",
1819
"@llvm-project//mlir:ArithDialect",
1920
"@llvm-project//mlir:DialectUtils",

lib/Transforms/LinalgCanonicalizations/LinalgCanonicalizations.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <cstdint>
44
#include <utility>
55

6+
#include "lib/Utils/TensorUtils.h"
67
#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
78
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
89
#include "mlir/include/mlir/Dialect/Linalg/IR/Linalg.h" // from @llvm-project
@@ -272,8 +273,13 @@ struct BroadcastToExpandShape
272273
broadcastOp.getInit().getType(), broadcastOp.getInput().getType());
273274
if (res != SliceVerificationResult::Success) return failure();
274275

276+
SmallVector<ReassociationIndices> expandingMap =
277+
getReassociationForReshapeAtDim(
278+
broadcastOp.getInit().getType().getRank(),
279+
broadcastOp.getDimensions());
275280
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
276-
broadcastOp, broadcastOp.getInit().getType(), broadcastOp.getInput());
281+
broadcastOp, broadcastOp.getInit().getType(), broadcastOp.getInput(),
282+
expandingMap);
277283
return success();
278284
}
279285
};

lib/Utils/TensorUtils.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
#include "lib/Utils/TensorUtils.h"
22

3+
#include <cassert>
34
#include <cstdint>
45

56
#include "llvm/include/llvm/ADT/SmallVectorExtras.h" // from @llvm-project
7+
#include "mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h" // from @llvm-project
68
#include "mlir/include/mlir/Dialect/Utils/StaticValueUtils.h" // from @llvm-project
79
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
810
#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project
@@ -48,5 +50,33 @@ SmallVector<int64_t> getIndicesFromRowMajorShape(int64_t flattenedIndex,
4850
return indices;
4951
}
5052

53+
SmallVector<ReassociationIndices> getReassociationForReshapeAtDim(
54+
int64_t rank, ArrayRef<int64_t> positions) {
55+
SmallVector<ReassociationIndices> reassociation;
56+
reassociation.reserve(rank - positions.size());
57+
58+
llvm::DenseMap<int64_t, bool> positionsMap;
59+
for (int64_t pos : positions) {
60+
positionsMap[pos] = true;
61+
}
62+
auto isUnitDim = [&](int64_t dim) { return positionsMap.contains(dim); };
63+
64+
ReassociationIndices reassociationGroup;
65+
unsigned dim = 0;
66+
while (dim < rank && isUnitDim(dim)) reassociationGroup.push_back(dim++);
67+
while (dim < rank) {
68+
assert(!isUnitDim(dim) && "expected non unit-extent");
69+
reassociationGroup.push_back(dim);
70+
++dim;
71+
// Fold all following dimensions that are unit-extent.
72+
while (dim < rank && isUnitDim(dim)) {
73+
reassociationGroup.push_back(dim++);
74+
}
75+
reassociation.push_back(reassociationGroup);
76+
reassociationGroup.clear();
77+
}
78+
return reassociation;
79+
}
80+
5181
} // namespace heir
5282
} // namespace mlir

lib/Utils/TensorUtils.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <cstdint>
55

6+
#include "mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h" // from @llvm-project
67
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
78
#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project
89
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
@@ -20,6 +21,15 @@ FailureOr<int64_t> getFlattenedIndex(RankedTensorType tensorType,
2021
SmallVector<int64_t> getIndicesFromRowMajorShape(int64_t flattenedIndex,
2122
SmallVector<int64_t> shape);
2223

24+
/// The following functions are copied from
25+
/// llvm-project/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp, where they
26+
/// are in an anonymous namespace.
27+
28+
/// Returns reassociation indices for collapsing/expanding a
29+
/// tensor of rank `rank` at positions in `positions`.
30+
SmallVector<ReassociationIndices> getReassociationForReshapeAtDim(
31+
int64_t rank, ArrayRef<int64_t> positions);
32+
2333
} // namespace heir
2434
} // namespace mlir
2535

tests/Transforms/linalg_canonicalizations/broadcast_to_expand.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,17 @@ module {
4242
func.return %collapsed_1 : tensor<2x3xf32>
4343
}
4444
}
45+
46+
// -----
47+
48+
module {
49+
// CHECK: func @main
50+
// CHECK-SAME: %[[arg0:.*]]: tensor<2x3xf32>
51+
// CHECK: %[[v0:.*]] = tensor.expand_shape %[[arg0]] {{\[\[}}0, 1], [2, 3]]
52+
// CHECK: return %[[v0]] : tensor<1x2x3x1xf32>
53+
func.func @main(%arg1 : tensor<2x3xf32>) -> (tensor<1x2x3x1xf32>) {
54+
%2 = tensor.empty() : tensor<1x2x3x1xf32>
55+
%broadcasted = linalg.broadcast ins(%arg1 : tensor<2x3xf32>) outs(%2 : tensor<1x2x3x1xf32>) dimensions = [0, 3]
56+
func.return %broadcasted : tensor<1x2x3x1xf32>
57+
}
58+
}

0 commit comments

Comments
 (0)