forked from OSchip/llvm-project
SDBM: support sum expressions on the LHS of stripe expressions
Introduce support for applying the stripe operator to sum expressions, as in (x + A) # B = x + A - (x + A) mod B. This is required to represent a combination of tiling and padding in the SDBM framework, and is a valid SDBM construct that was not originally supported. PiperOrigin-RevId: 269758807
This commit is contained in:
parent
a15e0ce1ba
commit
5709aeb993
|
@ -77,7 +77,7 @@ class SDBMTermExpr;
|
|||
/// - term
|
||||
/// - symbol
|
||||
/// - dimension
|
||||
/// - stripe <- (term, constant)
|
||||
/// - stripe <- (direct, constant)
|
||||
/// - negation <- (direct)
|
||||
/// - difference <- (direct, term)
|
||||
/// - constant
|
||||
|
@ -289,9 +289,9 @@ public:
|
|||
return expr.getKind() == SDBMExprKind::Stripe;
|
||||
}
|
||||
|
||||
static SDBMStripeExpr get(SDBMTermExpr var, SDBMConstantExpr stripeFactor);
|
||||
static SDBMStripeExpr get(SDBMDirectExpr var, SDBMConstantExpr stripeFactor);
|
||||
|
||||
SDBMTermExpr getVar() const;
|
||||
SDBMDirectExpr getLHS() const;
|
||||
SDBMConstantExpr getStripeFactor() const;
|
||||
};
|
||||
|
||||
|
@ -458,7 +458,7 @@ protected:
|
|||
walk<isPreorder>(diffExpr.getLHS());
|
||||
walk<isPreorder>(diffExpr.getRHS());
|
||||
} else if (auto stripeExpr = expr.dyn_cast<SDBMStripeExpr>()) {
|
||||
walk<isPreorder>(stripeExpr.getVar());
|
||||
walk<isPreorder>(stripeExpr.getLHS());
|
||||
walk<isPreorder>(stripeExpr.getStripeFactor());
|
||||
} else if (auto negExpr = expr.dyn_cast<SDBMNegExpr>()) {
|
||||
walk<isPreorder>(negExpr.getVar());
|
||||
|
|
|
@ -329,7 +329,7 @@ SDBM SDBM::get(ArrayRef<SDBMExpr> inequalities, ArrayRef<SDBMExpr> equalities) {
|
|||
// x - t - (C - 1) <= 0}.
|
||||
for (const auto &pair : result.stripeToPoint) {
|
||||
auto stripe = pair.second.cast<SDBMStripeExpr>();
|
||||
SDBMBuilderResult update = builder.visit(stripe.getVar());
|
||||
SDBMBuilderResult update = builder.visit(stripe.getLHS());
|
||||
assert(update.negativePos.size() == 1 && update.negativePos[0] == 0 &&
|
||||
"unexpected negated variable in stripe expression");
|
||||
assert(update.value == 0 &&
|
||||
|
@ -388,13 +388,13 @@ void SDBM::convertDBMElement(unsigned row, unsigned col, SDBMTermExpr rowExpr,
|
|||
SDBMExpr x1Expr, int64_t value) {
|
||||
if (stripeToPoint.count(x0)) {
|
||||
auto stripe = stripeToPoint[x0].cast<SDBMStripeExpr>();
|
||||
SDBMTermExpr var = stripe.getVar();
|
||||
SDBMDirectExpr var = stripe.getLHS();
|
||||
if (x1Expr == var && value >= 0)
|
||||
return true;
|
||||
}
|
||||
if (stripeToPoint.count(x1)) {
|
||||
auto stripe = stripeToPoint[x1].cast<SDBMStripeExpr>();
|
||||
SDBMTermExpr var = stripe.getVar();
|
||||
SDBMDirectExpr var = stripe.getLHS();
|
||||
if (x0Expr == var && value >= stripe.getStripeFactor().getValue() - 1)
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -173,7 +173,13 @@ void SDBMExpr::print(raw_ostream &os) const {
|
|||
void visitDim(SDBMDimExpr expr) { prn << 'd' << expr.getPosition(); }
|
||||
void visitSymbol(SDBMSymbolExpr expr) { prn << 's' << expr.getPosition(); }
|
||||
void visitStripe(SDBMStripeExpr expr) {
|
||||
visitTerm(expr.getVar());
|
||||
SDBMDirectExpr lhs = expr.getLHS();
|
||||
bool isTerm = lhs.isa<SDBMTermExpr>();
|
||||
if (!isTerm)
|
||||
prn << '(';
|
||||
visit(lhs);
|
||||
if (!isTerm)
|
||||
prn << ')';
|
||||
prn << " # ";
|
||||
visitConstant(expr.getStripeFactor());
|
||||
}
|
||||
|
@ -268,7 +274,7 @@ AffineExpr SDBMExpr::getAsAffineExpr() const {
|
|||
}
|
||||
|
||||
AffineExpr visitStripe(SDBMStripeExpr expr) {
|
||||
AffineExpr lhs = visit(expr.getVar()),
|
||||
AffineExpr lhs = visit(expr.getLHS()),
|
||||
rhs = visit(expr.getStripeFactor());
|
||||
return lhs - (lhs % rhs);
|
||||
}
|
||||
|
@ -434,8 +440,6 @@ Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
|
|||
// If RHS is a constant, we can always extend the SDBM expression to
|
||||
// include it by sinking the constant into the nearest sum expresion.
|
||||
if (auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>()) {
|
||||
assert(!lhs.isa<SDBMSumExpr>() && "unexpected non-canonicalized sum");
|
||||
|
||||
int64_t constant = rhsConstant.getValue();
|
||||
auto varying = lhs.dyn_cast<SDBMVaryingExpr>();
|
||||
assert(varying && "unexpected uncanonicalized sum of constants");
|
||||
|
@ -493,10 +497,10 @@ Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
|
|||
if (!lhs || !rhs)
|
||||
return {};
|
||||
|
||||
// 'mod' can only be converted to SDBM if its LHS is a variable
|
||||
// 'mod' can only be converted to SDBM if its LHS is a direct expression
|
||||
// and its RHS is a constant. Then it `x mod c = x - x stripe c`.
|
||||
auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
|
||||
auto lhsVar = lhs.dyn_cast<SDBMTermExpr>();
|
||||
auto lhsVar = lhs.dyn_cast<SDBMDirectExpr>();
|
||||
if (!lhsVar || !rhsConstant)
|
||||
return {};
|
||||
return SDBMDiffExpr::get(lhsVar,
|
||||
|
@ -568,7 +572,7 @@ int64_t SDBMDirectExpr::getConstant() {
|
|||
// SDBMStripeExpr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
SDBMStripeExpr SDBMStripeExpr::get(SDBMTermExpr var,
|
||||
SDBMStripeExpr SDBMStripeExpr::get(SDBMDirectExpr var,
|
||||
SDBMConstantExpr stripeFactor) {
|
||||
assert(var && "expected SDBM variable expression");
|
||||
assert(stripeFactor && "expected non-null stripe factor");
|
||||
|
@ -581,9 +585,9 @@ SDBMStripeExpr SDBMStripeExpr::get(SDBMTermExpr var,
|
|||
stripeFactor);
|
||||
}
|
||||
|
||||
SDBMTermExpr SDBMStripeExpr::getVar() const {
|
||||
SDBMDirectExpr SDBMStripeExpr::getLHS() const {
|
||||
if (SDBMVaryingExpr lhs = static_cast<ImplType *>(impl)->lhs)
|
||||
return lhs.cast<SDBMTermExpr>();
|
||||
return lhs.cast<SDBMDirectExpr>();
|
||||
return {};
|
||||
}
|
||||
|
||||
|
@ -738,7 +742,7 @@ SDBMExpr stripe(SDBMExpr expr, SDBMExpr factor) {
|
|||
if (constantFactor.getValue() == 1)
|
||||
return expr;
|
||||
|
||||
return SDBMStripeExpr::get(expr.cast<SDBMTermExpr>(), constantFactor);
|
||||
return SDBMStripeExpr::get(expr.cast<SDBMDirectExpr>(), constantFactor);
|
||||
}
|
||||
|
||||
} // namespace ops_assertions
|
||||
|
|
|
@ -180,7 +180,7 @@ TEST(SDBMOperators, Stripe) {
|
|||
auto expr = stripe(dim(0), 3);
|
||||
auto stripeExpr = expr.dyn_cast<SDBMStripeExpr>();
|
||||
ASSERT_TRUE(stripeExpr);
|
||||
EXPECT_EQ(stripeExpr.getVar(), dim(0));
|
||||
EXPECT_EQ(stripeExpr.getLHS(), dim(0));
|
||||
EXPECT_EQ(stripeExpr.getStripeFactor().getValue(), 3);
|
||||
}
|
||||
|
||||
|
@ -286,7 +286,7 @@ TEST(SDBMExpr, Stripe) {
|
|||
|
||||
// We can create stripe expressions and query them.
|
||||
auto expr = SDBMStripeExpr::get(var, cst2);
|
||||
EXPECT_EQ(expr.getVar(), var);
|
||||
EXPECT_EQ(expr.getLHS(), var);
|
||||
EXPECT_EQ(expr.getStripeFactor(), cst2);
|
||||
|
||||
// Two separately created stripe expressions with the same LHS and RHS are
|
||||
|
@ -300,6 +300,9 @@ TEST(SDBMExpr, Stripe) {
|
|||
// Non-positive stripe factors are not allowed.
|
||||
EXPECT_DEATH(SDBMStripeExpr::get(var, cst0), "non-positive");
|
||||
|
||||
// Stripes can have sums on the LHS.
|
||||
SDBMStripeExpr::get(SDBMSumExpr::get(var, cst2), cst2);
|
||||
|
||||
// Hierarchy is okay.
|
||||
auto generic = static_cast<SDBMExpr>(expr);
|
||||
EXPECT_TRUE(generic.isa<SDBMStripeExpr>());
|
||||
|
@ -395,6 +398,14 @@ TEST(SDBMExpr, AffineRoundTrip) {
|
|||
ASSERT_TRUE(roundtripped.hasValue());
|
||||
EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(outerStripe));
|
||||
|
||||
// Check that ((s0 + 2) # 5) can be round-tripped through AffineExpr, i.e.
|
||||
// stripe detection supports sum expressions.
|
||||
auto inner = SDBMSumExpr::get(var, cst2);
|
||||
auto stripeSum = SDBMStripeExpr::get(inner, cst5);
|
||||
roundtripped = SDBMExpr::tryConvertAffineExpr(stripeSum.getAsAffineExpr());
|
||||
ASSERT_TRUE(roundtripped.hasValue());
|
||||
EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(stripeSum));
|
||||
|
||||
// Check that (s0 # 2 # 5 - s0 # 2) + 2 can be converted as an example of a
|
||||
// deeper expression tree.
|
||||
auto sum = SDBMSumExpr::get(outerStripe, cst2);
|
||||
|
@ -402,6 +413,15 @@ TEST(SDBMExpr, AffineRoundTrip) {
|
|||
roundtripped = SDBMExpr::tryConvertAffineExpr(diff.getAsAffineExpr());
|
||||
ASSERT_TRUE(roundtripped.hasValue());
|
||||
EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(diff));
|
||||
|
||||
// Check a nested stripe-sum combination.
|
||||
auto cst7 = SDBMConstantExpr::get(dialect(), 7);
|
||||
auto nestedStripe =
|
||||
SDBMStripeExpr::get(SDBMSumExpr::get(stripeSum, cst2), cst7);
|
||||
diff = SDBMDiffExpr::get(nestedStripe, stripe);
|
||||
roundtripped = SDBMExpr::tryConvertAffineExpr(diff.getAsAffineExpr());
|
||||
ASSERT_TRUE(roundtripped.hasValue());
|
||||
EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(diff));
|
||||
}
|
||||
|
||||
TEST(SDBMExpr, MatchStripeMulPattern) {
|
||||
|
|
Loading…
Reference in New Issue