Skip to content

Commit 00832e9

Browse files
Merge pull request #2163 from j2kun:rotate-and-reduce-indivisible
PiperOrigin-RevId: 800611385
2 parents 41614ed + 3866f51 commit 00832e9

File tree

4 files changed

+61
-15
lines changed

4 files changed

+61
-15
lines changed

lib/Dialect/TensorExt/Transforms/ImplementRotateAndReduce.cpp

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -49,21 +49,6 @@ LogicalResult convertRotateAndReduceOp(RotateAndReduceOp op) {
4949
}
5050

5151
TypedValue<RankedTensorType> plaintexts = op.getPlaintexts();
52-
53-
// Validate divisibility of step size
54-
auto babySteps = static_cast<int64_t>(std::floor(std::sqrt(steps)));
55-
unsigned giantSteps = steps / babySteps;
56-
if (giantSteps * babySteps != steps) {
57-
return op.emitOpError()
58-
<< "requires steps to be a multiple of sqrt(steps), but found "
59-
"steps="
60-
<< steps << " and babySteps=" << babySteps;
61-
}
62-
LLVM_DEBUG(llvm::dbgs()
63-
<< "Using baby-step / giant-step decomposition of sum of size "
64-
<< steps << " with babySteps= " << babySteps
65-
<< " and giantSteps= " << giantSteps << "\n");
66-
6752
auto plaintextsLeaf = std::optional<SSAValue>(plaintexts);
6853
implementedKernel =
6954
implementRotateAndReduce(vectorLeaf, plaintextsLeaf, period, steps);

lib/Kernel/KernelImplementation.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,26 @@ implementRotateAndReduce(const T& vector, std::optional<T> plaintexts,
149149

150150
// Use a value of sqrt(n) as the baby step / giant step size.
151151
int64_t numBabySteps = static_cast<int64_t>(std::floor(std::sqrt(steps)));
152+
if (steps % numBabySteps != 0) {
153+
// Find the nearest divisible number to use for baby step
154+
// TODO(#2162): determine the right tradeoff here
155+
int lower = numBabySteps;
156+
int upper = numBabySteps;
157+
158+
while (steps % lower != 0 && steps % upper != steps) {
159+
lower--;
160+
upper++;
161+
}
162+
163+
if (steps % lower == 0 && lower > 1) {
164+
numBabySteps = lower;
165+
} else if (steps % upper == 0) {
166+
numBabySteps = upper;
167+
} else {
168+
numBabySteps = steps;
169+
}
170+
}
171+
152172
int64_t giantStepSize = numBabySteps;
153173
int64_t numGiantSteps = steps / numBabySteps;
154174

lib/Kernel/RotateAndReduceImplTest.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,44 @@ TEST(RotateAndReduceImplTest, TestPeriod2WithNoPlaintext) {
153153
EXPECT_EQ(expected, actual);
154154
}
155155

156+
TEST(RotateAndReduceImplTest, TestUnitPeriodWithIndivisibleN10) {
157+
int64_t n = 10;
158+
int64_t period = 1;
159+
160+
std::vector<int> vector = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
161+
std::vector<std::vector<int>> plaintexts = {
162+
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
163+
{2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
164+
{3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
165+
{4, 5, 6, 7, 8, 9, 10, 11, 12, 13},
166+
{5, 6, 7, 8, 9, 10, 11, 12, 13, 14},
167+
{6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
168+
{7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
169+
{8, 9, 10, 11, 12, 13, 14, 15, 16, 17},
170+
{9, 10, 11, 12, 13, 14, 15, 16, 17, 18},
171+
{10, 11, 12, 13, 14, 15, 16, 17, 18, 19},
172+
};
173+
174+
std::vector<int> expected = runNaive(vector, plaintexts, period, n);
175+
std::vector<int> actual = runImpl(vector, plaintexts, period, n);
176+
EXPECT_EQ(expected, actual);
177+
}
178+
179+
TEST(RotateAndReduceImplTest, TestUnitPeriodWithIndivisibleN6) {
180+
int64_t n = 6;
181+
int64_t period = 1;
182+
183+
std::vector<int> vector = {0, 1, 2, 3, 4, 5};
184+
std::vector<std::vector<int>> plaintexts = {
185+
{1, 2, 3, 4, 5, 6}, {2, 3, 4, 5, 6, 7}, {3, 4, 5, 6, 7, 8},
186+
{4, 5, 6, 7, 8, 9}, {5, 6, 7, 8, 9, 10}, {6, 7, 8, 9, 10, 11},
187+
};
188+
189+
std::vector<int> expected = runNaive(vector, plaintexts, period, n);
190+
std::vector<int> actual = runImpl(vector, plaintexts, period, n);
191+
EXPECT_EQ(expected, actual);
192+
}
193+
156194
} // namespace
157195
} // namespace kernel
158196
} // namespace heir

lib/Kernel/TestingUtils.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,9 @@ LiteralValue EvalVisitor::operator()(const LeftRotateNode<LiteralValue>& node) {
8686
auto operand = this->process(node.operand);
8787
auto dim = operand.getShape()[0];
8888
int amount = node.shift;
89+
// Normalize amount to be in [0, dim)
90+
amount = ((amount % dim) + dim) % dim;
91+
8992
const auto& oVal = operand.getTensor();
9093
const auto* oVec = std::get_if<std::vector<int>>(&oVal);
9194
assert(oVec && "unsupported rotate operand");

0 commit comments

Comments
 (0)