Skip to content

Commit 9065f9b

Browse files
committed
[TSAR, Memory] Use assumption map for processOneStartOtherEndConst()
1 parent 79e0cbd commit 9065f9b

File tree

4 files changed

+186
-105
lines changed

4 files changed

+186
-105
lines changed

include/tsar/Analysis/Memory/MemoryLocationRange.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,8 +236,10 @@ struct MemoryLocationRange {
236236
/// exact differences will not be saved.
237237
/// \return The result of intersection. If it is None, intersection is empty.
238238
/// If it is a location, but `Ptr` of the returned location is `nullptr`, then
239-
/// the intersection may exist but can't be calculated. Otherwise, the
240-
/// returned location is an exact intersection.
239+
/// the intersection may exist but can't be calculated (note that you will get
240+
/// the same result if the intersection is exact but LC or RC is not `nullptr`
241+
/// and we can't find the exact differences). Otherwise, the returned location
242+
/// is an exact intersection.
241243
llvm::Optional<MemoryLocationRange> intersect(
242244
MemoryLocationRange LHS,
243245
MemoryLocationRange RHS,

include/tsar/Support/SCEVUtils.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,13 @@ const llvm::SCEV *addSCEVAndCast(const llvm::SCEV *LHS,
8282
const llvm::SCEV *RHS,
8383
llvm::ScalarEvolution *SE);
8484

85+
/// Add 1 to the expression S.
86+
const llvm::SCEV *addOneToSCEV(const llvm::SCEV *S, llvm::ScalarEvolution *SE);
87+
88+
/// Subtract 1 from the expression S.
89+
const llvm::SCEV *subtractOneFromSCEV(const llvm::SCEV *S,
90+
llvm::ScalarEvolution *SE);
91+
8592
/// If LHS and RHS have the same sequence of type casts, subtract them and cast
8693
/// the result back to the original type.
8794
const llvm::SCEV *subtractSCEVAndCast(const llvm::SCEV *LHS,

lib/Analysis/Memory/MemoryLocationRange.cpp

Lines changed: 167 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -342,12 +342,17 @@ struct IntersectVarInfo {
342342
IntersectionResult TrueIntersection;
343343
llvm::Optional<int64_t> CmpStart;
344344
llvm::Optional<int64_t> CmpEnd;
345-
llvm::Optional<int64_t> CmpLERS;
346-
llvm::Optional<int64_t> CmpLSRE;
345+
llvm::Optional<int64_t> CmpLeftEndRightStart;
346+
llvm::Optional<int64_t> CmpLeftStartRightEnd;
347347
bool IsValidStep;
348348
};
349349

350-
// A + Const or A
350+
/// We can compare the bounds of the segments only if they can be represented
351+
/// as f(V) = V + C, where V is a variable expression and C is an integer
352+
/// constant. For the expression S representing one bound of the segment
353+
/// this function returns a pair consisting of an llvm::Value* of V and an
354+
/// integer value of C. If S does not have the form f(V), this function
355+
/// returns llvm::None.
351356
llvm::Optional<std::pair<Value *, int64_t>>
352357
parseBoundExpression(const llvm::SCEV *S) {
353358
if (auto *Unknown = dyn_cast<SCEVUnknown>(S))
@@ -375,62 +380,57 @@ parseBoundExpression(const llvm::SCEV *S) {
375380
return std::make_pair(Variable, Constant);
376381
}
377382

383+
inline std::function<Dimension& (llvm::SmallVectorImpl<MemoryLocationRange> *)>
384+
getGrowFunction(const MemoryLocationRange &LeftRange, std::size_t DimIdx) {
385+
auto Lambda = [&LeftRange, DimIdx](
386+
llvm::SmallVectorImpl<MemoryLocationRange> *C) -> Dimension & {
387+
return C->emplace_back(LeftRange).DimList[DimIdx];
388+
};
389+
return Lambda;
390+
}
391+
378392
IntersectionResult processBothVariable(const MemoryLocationRange &LeftRange,
379-
std::size_t DimIdx, const Dimension &Right, Dimension &Intersection,
380-
llvm::SmallVectorImpl<MemoryLocationRange> *LC,
381-
llvm::SmallVectorImpl<MemoryLocationRange> *RC, IntersectVarInfo &Info) {
393+
std::size_t DimIdx, const Dimension &Right, const IntersectVarInfo &Info,
394+
Dimension &Intersection, llvm::SmallVectorImpl<MemoryLocationRange> *LC,
395+
llvm::SmallVectorImpl<MemoryLocationRange> *RC) {
396+
auto Grow = getGrowFunction(LeftRange, DimIdx);
382397
auto &Left = LeftRange.DimList[DimIdx];
383398
auto SE = LeftRange.SE;
384399
auto &CmpStart = Info.CmpStart;
385400
auto &CmpEnd = Info.CmpEnd;
386-
auto &CmpLSRE = Info.CmpLSRE;
387-
auto &CmpLERS = Info.CmpLERS;
388-
if (Info.CmpLERS && *Info.CmpLERS < 0 || Info.CmpLSRE && *Info.CmpLSRE > 0)
401+
auto &CmpLSRE = Info.CmpLeftStartRightEnd;
402+
auto &CmpLERS = Info.CmpLeftEndRightStart;
403+
if (CmpLERS && *CmpLERS < 0 || CmpLSRE && *CmpLSRE > 0)
389404
return Info.EmptyIntersection;
390405
if (!CmpStart || !CmpEnd || !Info.IsValidStep)
391406
return Info.UnknownIntersection;
392-
if (*CmpStart >= 0 && *CmpEnd <= 0) {
393-
Intersection = Left;
394-
if (RC && *CmpStart > 0) {
395-
auto &Dim = RC->emplace_back(LeftRange).DimList[DimIdx];
396-
Dim.Start = Right.Start;
397-
Dim.End = subtractSCEVAndCast(Left.Start,
398-
SE->getOne(Left.Start->getType()), SE);
399-
}
400-
if (RC && *CmpEnd < 0) {
401-
auto &Dim = RC->emplace_back(LeftRange).DimList[DimIdx];
402-
Dim.Start = addSCEVAndCast(
403-
Left.End, SE->getOne(Left.End->getType()), SE);
404-
Dim.End = Right.End;
405-
}
406-
} else if (*CmpStart <= 0 && *CmpEnd >= 0 && Left.Step == Right.Step) {
407-
Intersection = Right;
408-
if (LC && *CmpStart < 0) {
409-
auto &Dim = LC->emplace_back(LeftRange).DimList[DimIdx];
410-
Dim.Start = Left.Start;
411-
Dim.End = subtractSCEVAndCast(Right.Start,
412-
SE->getOne(Right.Start->getType()), SE);
413-
}
414-
if (LC && *CmpEnd > 0) {
415-
auto &Dim = LC->emplace_back(LeftRange).DimList[DimIdx];
416-
Dim.Start = addSCEVAndCast(
417-
Right.End, SE->getOne(Right.End->getType()), SE);
418-
Dim.End = Left.End;
419-
}
420-
} else if (CmpLERS && CmpLSRE && *CmpLERS >= 0 && *CmpLSRE <= 0) {
407+
if (CmpLERS && CmpLSRE && *CmpLERS >= 0 && *CmpLSRE <= 0) {
408+
// The left dimension overlaps the right.
421409
Intersection.Start = (*CmpStart) > 0 ? Left.Start : Right.Start;
422410
Intersection.End = (*CmpEnd) < 0 ? Left.End : Right.End;
423-
if (LC && *CmpStart < 0) {
424-
auto &Dim = LC->emplace_back(LeftRange).DimList[DimIdx];
425-
Dim.Start = Left.Start;
426-
Dim.End = subtractSCEVAndCast(Right.Start,
427-
SE->getOne(Right.Start->getType()), SE);
411+
if (LC) {
412+
if (*CmpStart < 0) {
413+
auto &Dim = Grow(LC);
414+
Dim.Start = Left.Start;
415+
Dim.End = subtractOneFromSCEV(Right.Start, SE);
416+
}
417+
if (*CmpEnd > 0) {
418+
auto &Dim = Grow(LC);
419+
Dim.Start = addOneToSCEV(Right.End, SE);
420+
Dim.End = Left.End;
421+
}
428422
}
429-
if (RC && *CmpEnd < 0) {
430-
auto &Dim = RC->emplace_back(LeftRange).DimList[DimIdx];
431-
Dim.Start = addSCEVAndCast(
432-
Left.End, SE->getOne(Left.End->getType()), SE);
433-
Dim.End = Right.End;
423+
if (RC) {
424+
if (*CmpStart > 0) {
425+
auto &Dim = Grow(RC);
426+
Dim.Start = Right.Start;
427+
Dim.End = subtractOneFromSCEV(Left.Start, SE);
428+
}
429+
if (*CmpEnd < 0) {
430+
auto &Dim = Grow(RC);
431+
Dim.Start = addOneToSCEV(Left.End, SE);
432+
Dim.End = Right.End;
433+
}
434434
}
435435
} else
436436
return Info.UnknownIntersection;
@@ -442,60 +442,123 @@ IntersectionResult processOneStartOtherEndConst(
442442
const Dimension &Right, Dimension &Intersection,
443443
llvm::SmallVectorImpl<MemoryLocationRange> *LC,
444444
llvm::SmallVectorImpl<MemoryLocationRange> *RC, IntersectVarInfo &Info) {
445+
auto Grow = getGrowFunction(LeftRange, DimIdx);
445446
auto &Left = LeftRange.DimList[DimIdx];
447+
assert((isa<SCEVConstant>(Left.Start) && !isa<SCEVConstant>(Left.End) &&
448+
!isa<SCEVConstant>(Right.Start) && isa<SCEVConstant>(Right.End)) ||
449+
(isa<SCEVConstant>(Right.Start) && !isa<SCEVConstant>(Right.End) &&
450+
!isa<SCEVConstant>(Left.Start) && isa<SCEVConstant>(Left.End)));
446451
auto SE = LeftRange.SE;
447-
auto &CmpLSRE = Info.CmpLSRE;
448-
auto &CmpLERS = Info.CmpLERS;
449-
if (isa<SCEVConstant>(Left.Start)) {
450-
if (!CmpLERS)
451-
return Info.UnknownIntersection;
452-
if (*CmpLERS >= 0) {
453-
if (Info.IsValidStep) {
454-
Intersection.Start = Right.Start;
455-
Intersection.End = Left.End;
456-
if (LC) {
457-
auto &Dim = LC->emplace_back(LeftRange).DimList[DimIdx];
458-
Dim.Start = Left.Start;
459-
Dim.End = subtractSCEVAndCast(Right.Start,
460-
SE->getOne(Right.Start->getType()), SE);
461-
}
462-
if (RC) {
463-
auto &Dim = RC->emplace_back(LeftRange).DimList[DimIdx];
464-
Dim.Start = addSCEVAndCast(Left.End, SE->getOne(Left.End->getType()),
465-
SE);
466-
Dim.End = Right.End;
467-
}
468-
} else {
469-
return Info.UnknownIntersection;
452+
auto *AM = LeftRange.AM;
453+
// Let the First dimension be a dimension whose Start is constant and End
454+
// is variable. We will denote the segments as follows:
455+
// First = [m, N], Second = [P, q], where lowercase letters mean constants,
456+
// and capital letters mean variables.
457+
auto *First { &Left }, *Second { &Right };
458+
auto *FC { LC }, *SC { RC };
459+
auto &Cmpmq { Info.CmpLeftStartRightEnd },
460+
&CmpNP { Info.CmpLeftEndRightStart };
461+
if (!isa<SCEVConstant>(Left.Start)) {
462+
std::swap(First, Second);
463+
std::swap(FC, SC);
464+
std::swap(Cmpmq, CmpNP);
465+
}
466+
assert(Cmpmq && "Constant bounds must be comparable!");
467+
if (*Cmpmq > 0)
468+
return Info.EmptyIntersection;
469+
if (!CmpNP)
470+
return Info.UnknownIntersection;
471+
if (*CmpNP < 0)
472+
return Info.EmptyIntersection;
473+
if (!Info.IsValidStep)
474+
return Info.UnknownIntersection;
475+
// N >= P
476+
auto *M = First->Start, *N = First->End;
477+
auto *P = Second->Start, *Q = Second->End;
478+
auto BN = parseBoundExpression(N);
479+
auto BP = parseBoundExpression(P);
480+
if (!BP || !BN || !AM)
481+
return Info.UnknownIntersection;
482+
auto BPItr = AM->find(BP->first);
483+
auto BNItr = AM->find(BN->first);
484+
if (BPItr == AM->end() || BNItr == AM->end())
485+
return Info.UnknownIntersection;
486+
auto &BoundsP = BPItr->second;
487+
auto &BoundsN = BNItr->second;
488+
auto MInt = cast<SCEVConstant>(M)->getAPInt().getSExtValue();
489+
auto QInt = cast<SCEVConstant>(Q)->getAPInt().getSExtValue();
490+
if (!BoundsN.Lower || !BoundsN.Upper || !BoundsP.Lower || !BoundsP.Upper)
491+
return Info.UnknownIntersection;
492+
if (MInt < *BoundsP.Lower) {
493+
if (QInt > *BoundsN.Upper) {
494+
Intersection.Start = P;
495+
Intersection.End = N;
496+
if (FC) {
497+
// [m, P-1]
498+
auto &Dim = Grow(FC);
499+
Dim.Start = M;
500+
Dim.End = subtractOneFromSCEV(P, SE);
501+
}
502+
if (SC) {
503+
// [N+1, q]
504+
auto &Dim = Grow(SC);
505+
Dim.Start = addOneToSCEV(N, SE);
506+
Dim.End = Q;
507+
}
508+
} else if (QInt < *BoundsN.Lower) {
509+
Intersection = *Second;
510+
if (FC) {
511+
// [m, P-1], [q+1, N]
512+
auto &Dim1 = Grow(FC);
513+
Dim1.Start = M;
514+
Dim1.End = subtractOneFromSCEV(P, SE);
515+
auto &Dim2 = Grow(FC);
516+
Dim2.Start = addOneToSCEV(Q, SE);
517+
Dim2.End = N;
470518
}
471519
} else {
472-
return Info.EmptyIntersection;
520+
return Info.UnknownIntersection;
473521
}
474-
} else {
475-
if (!CmpLSRE)
522+
} else if (MInt <= *BoundsP.Lower && *BoundsN.Lower >= QInt) {
523+
if (FC || SC)
476524
return Info.UnknownIntersection;
477-
if (*CmpLSRE <= 0) {
478-
if (Info.IsValidStep) {
479-
Intersection.Start = Left.Start;
480-
Intersection.End = Right.End;
481-
if (LC) {
482-
auto &Dim = LC->emplace_back(LeftRange).DimList[DimIdx];
483-
Dim.Start = Right.Start;
484-
Dim.End = subtractSCEVAndCast(Left.Start,
485-
SE->getOne(Left.Start->getType()), SE);
486-
}
487-
if (RC) {
488-
auto &Dim = RC->emplace_back(LeftRange).DimList[DimIdx];
489-
Dim.Start = addSCEVAndCast(
490-
Right.End, SE->getOne(Right.End->getType()), SE);
491-
Dim.End = Left.End;
492-
}
493-
} else {
494-
return Info.UnknownIntersection;
525+
Intersection = *Second;
526+
} else if (MInt > *BoundsP.Upper) {
527+
if (QInt < BoundsN.Lower) {
528+
Intersection.Start = M;
529+
Intersection.End = Q;
530+
if (FC) {
531+
// [P, m-1]
532+
auto &Dim = Grow(FC);
533+
Dim.Start = P;
534+
Dim.End = subtractOneFromSCEV(M, SE);
535+
}
536+
if (SC) {
537+
// [q+1, N]
538+
auto &Dim = Grow(SC);
539+
Dim.Start = addOneToSCEV(Q, SE);
540+
Dim.End = N;
541+
}
542+
} else if (QInt > BoundsN.Upper) {
543+
Intersection = *First;
544+
if (SC) {
545+
// [P, m-1], [N+1, q]
546+
auto &Dim1 = Grow(SC);
547+
Dim1.Start = P;
548+
Dim1.End = subtractOneFromSCEV(M, SE);
549+
auto &Dim2 = Grow(SC);
550+
Dim2.Start = subtractOneFromSCEV(N, SE);
551+
Dim2.End = Q;
495552
}
496553
} else {
497-
return Info.EmptyIntersection;
554+
return Info.UnknownIntersection;
498555
}
556+
} else if (MInt >= *BoundsP.Upper && BoundsN.Upper <= QInt) {
557+
if (FC || SC)
558+
return Info.UnknownIntersection;
559+
Intersection = *First;
560+
} else {
561+
return Info.UnknownIntersection;
499562
}
500563
return Info.TrueIntersection;
501564
}
@@ -504,14 +567,15 @@ IntersectionResult processBothStartConst(const MemoryLocationRange &LeftRange,
504567
std::size_t DimIdx, const Dimension &Right, Dimension &Intersection,
505568
llvm::SmallVectorImpl<MemoryLocationRange> *LC,
506569
llvm::SmallVectorImpl<MemoryLocationRange> *RC, IntersectVarInfo &Info) {
507-
assert(DimIdx < LeftRange.DimList.size() && "DimIdx must match the size of LeftRange.DimList!");
570+
assert(DimIdx < LeftRange.DimList.size() &&
571+
"DimIdx must match the size of LeftRange.DimList!");
508572
auto &Left = LeftRange.DimList[DimIdx];
509573
auto SE = LeftRange.SE;
510574
auto *AM = LeftRange.AM;
511575
auto &CmpStart = Info.CmpStart;
512576
auto &CmpEnd = Info.CmpEnd;
513-
auto &CmpLSRE = Info.CmpLSRE;
514-
auto &CmpLERS = Info.CmpLERS;
577+
auto &CmpLSRE = Info.CmpLeftStartRightEnd;
578+
auto &CmpLERS = Info.CmpLeftEndRightStart;
515579
if (!CmpEnd || !Info.IsValidStep)
516580
return Info.UnknownIntersection;
517581
assert(CmpStart && "Starts of dimensions must be comparable!");
@@ -603,8 +667,8 @@ IntersectionResult processBothEndConst(const MemoryLocationRange &LeftRange,
603667
auto SE = LeftRange.SE;
604668
auto &CmpStart = Info.CmpStart;
605669
auto &CmpEnd = Info.CmpEnd;
606-
auto &CmpLSRE = Info.CmpLSRE;
607-
auto &CmpLERS = Info.CmpLERS;
670+
auto &CmpLSRE = Info.CmpLeftStartRightEnd;
671+
auto &CmpLERS = Info.CmpLeftEndRightStart;
608672
if (!CmpStart || Left.End != Right.End || !Info.IsValidStep)
609673
return Info.UnknownIntersection;
610674
Intersection.Start = *CmpStart > 0 ? Left.Start : Right.Start;
@@ -687,7 +751,7 @@ IntersectionResult processOneVariableOtherSemiconst(
687751
llvm::SmallVectorImpl<MemoryLocationRange> *RC, IntersectVarInfo &Info) {
688752
auto &Left = LeftRange.DimList[DimIdx];
689753
auto SE = LeftRange.SE;
690-
auto &CmpLSRE = Info.CmpLSRE;
754+
auto &CmpLSRE = Info.CmpLeftStartRightEnd;
691755
const Dimension *FullVar = &Left, *Semiconst = &Right;
692756
if (isa<SCEVConstant>(Left.Start) || isa<SCEVConstant>(Left.End))
693757
std::swap(FullVar, Semiconst);
@@ -772,12 +836,12 @@ IntersectionResult intersectVarDims(const MemoryLocationRange &LeftRange,
772836
cast<SCEVConstant>(Left.Step)->getAPInt().getSExtValue() == 1;
773837
Info.CmpStart = compareSCEVs(Left.Start, Right.Start, LeftRange.SE);
774838
Info.CmpEnd = compareSCEVs(Left.End, Right.End, LeftRange.SE);
775-
Info.CmpLERS = compareSCEVs(Left.End, Right.Start, LeftRange.SE);
776-
Info.CmpLSRE = compareSCEVs(Left.Start, Right.End, LeftRange.SE);
839+
Info.CmpLeftEndRightStart = compareSCEVs(Left.End, Right.Start, LeftRange.SE);
840+
Info.CmpLeftStartRightEnd = compareSCEVs(Left.Start, Right.End, LeftRange.SE);
777841
switch (PairKind) {
778842
case DimPairKind::BothVariable:
779-
return processBothVariable(LeftRange, DimIdx, Right, Intersection,
780-
LC, RC, Info);
843+
return processBothVariable(LeftRange, DimIdx, Right, Info, Intersection,
844+
LC, RC);
781845
case DimPairKind::OneStartOtherEndConst:
782846
return processOneStartOtherEndConst(LeftRange, DimIdx, Right,
783847
Intersection, LC, RC, Info);

lib/Support/SCEVUtils.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,14 @@ const SCEV *addSCEVAndCast(const SCEV *LHS,
696696
return restoreCasts(TQLHS, SE->getAddExpr(InnerLHS, InnerRHS), SE);
697697
}
698698

699+
const SCEV *addOneToSCEV(const SCEV *S, ScalarEvolution *SE) {
700+
return addSCEVAndCast(S, SE->getOne(S->getType()), SE);
701+
}
702+
703+
const SCEV *subtractOneFromSCEV(const SCEV *S, ScalarEvolution *SE) {
704+
return subtractSCEVAndCast(S, SE->getOne(S->getType()), SE);
705+
}
706+
699707
const SCEV *evaluateAtIteration(const SCEVAddRecExpr *ARE, const SCEV *It,
700708
ScalarEvolution *SE) {
701709
assert(SE && "ScalarEvolution must be specified!");

0 commit comments

Comments
 (0)