SDBM: support sum expressions on the LHS of stripe expressions

Introduce support for applying the stripe operator to sum expressions, as in
  (x + A) # B = x + A - (x + A) mod B.
This is required to represent a combination of tiling and padding in the SDBM
framework, and is a valid SDBM construct that was not originally supported.

PiperOrigin-RevId: 269758807
This commit is contained in:
Alex Zinenko 2019-09-18 02:16:59 -07:00 committed by A. Unique TensorFlower
parent a15e0ce1ba
commit 5709aeb993
4 changed files with 43 additions and 19 deletions

View File

@ -77,7 +77,7 @@ class SDBMTermExpr;
/// - term
/// - symbol
/// - dimension
/// - stripe <- (term, constant)
/// - stripe <- (direct, constant)
/// - negation <- (direct)
/// - difference <- (direct, term)
/// - constant
@ -289,9 +289,9 @@ public:
return expr.getKind() == SDBMExprKind::Stripe;
}
static SDBMStripeExpr get(SDBMTermExpr var, SDBMConstantExpr stripeFactor);
static SDBMStripeExpr get(SDBMDirectExpr var, SDBMConstantExpr stripeFactor);
SDBMTermExpr getVar() const;
SDBMDirectExpr getLHS() const;
SDBMConstantExpr getStripeFactor() const;
};
@ -458,7 +458,7 @@ protected:
walk<isPreorder>(diffExpr.getLHS());
walk<isPreorder>(diffExpr.getRHS());
} else if (auto stripeExpr = expr.dyn_cast<SDBMStripeExpr>()) {
walk<isPreorder>(stripeExpr.getVar());
walk<isPreorder>(stripeExpr.getLHS());
walk<isPreorder>(stripeExpr.getStripeFactor());
} else if (auto negExpr = expr.dyn_cast<SDBMNegExpr>()) {
walk<isPreorder>(negExpr.getVar());

View File

@ -329,7 +329,7 @@ SDBM SDBM::get(ArrayRef<SDBMExpr> inequalities, ArrayRef<SDBMExpr> equalities) {
// x - t - (C - 1) <= 0}.
for (const auto &pair : result.stripeToPoint) {
auto stripe = pair.second.cast<SDBMStripeExpr>();
SDBMBuilderResult update = builder.visit(stripe.getVar());
SDBMBuilderResult update = builder.visit(stripe.getLHS());
assert(update.negativePos.size() == 1 && update.negativePos[0] == 0 &&
"unexpected negated variable in stripe expression");
assert(update.value == 0 &&
@ -388,13 +388,13 @@ void SDBM::convertDBMElement(unsigned row, unsigned col, SDBMTermExpr rowExpr,
SDBMExpr x1Expr, int64_t value) {
if (stripeToPoint.count(x0)) {
auto stripe = stripeToPoint[x0].cast<SDBMStripeExpr>();
SDBMTermExpr var = stripe.getVar();
SDBMDirectExpr var = stripe.getLHS();
if (x1Expr == var && value >= 0)
return true;
}
if (stripeToPoint.count(x1)) {
auto stripe = stripeToPoint[x1].cast<SDBMStripeExpr>();
SDBMTermExpr var = stripe.getVar();
SDBMDirectExpr var = stripe.getLHS();
if (x0Expr == var && value >= stripe.getStripeFactor().getValue() - 1)
return true;
}

View File

@ -173,7 +173,13 @@ void SDBMExpr::print(raw_ostream &os) const {
void visitDim(SDBMDimExpr expr) { prn << 'd' << expr.getPosition(); }
void visitSymbol(SDBMSymbolExpr expr) { prn << 's' << expr.getPosition(); }
void visitStripe(SDBMStripeExpr expr) {
visitTerm(expr.getVar());
SDBMDirectExpr lhs = expr.getLHS();
bool isTerm = lhs.isa<SDBMTermExpr>();
if (!isTerm)
prn << '(';
visit(lhs);
if (!isTerm)
prn << ')';
prn << " # ";
visitConstant(expr.getStripeFactor());
}
@ -268,7 +274,7 @@ AffineExpr SDBMExpr::getAsAffineExpr() const {
}
AffineExpr visitStripe(SDBMStripeExpr expr) {
AffineExpr lhs = visit(expr.getVar()),
AffineExpr lhs = visit(expr.getLHS()),
rhs = visit(expr.getStripeFactor());
return lhs - (lhs % rhs);
}
@ -434,8 +440,6 @@ Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
// If RHS is a constant, we can always extend the SDBM expression to
// include it by sinking the constant into the nearest sum expresion.
if (auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>()) {
assert(!lhs.isa<SDBMSumExpr>() && "unexpected non-canonicalized sum");
int64_t constant = rhsConstant.getValue();
auto varying = lhs.dyn_cast<SDBMVaryingExpr>();
assert(varying && "unexpected uncanonicalized sum of constants");
@ -493,10 +497,10 @@ Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
if (!lhs || !rhs)
return {};
// 'mod' can only be converted to SDBM if its LHS is a variable
// 'mod' can only be converted to SDBM if its LHS is a direct expression
// and its RHS is a constant. Then it `x mod c = x - x stripe c`.
auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
auto lhsVar = lhs.dyn_cast<SDBMTermExpr>();
auto lhsVar = lhs.dyn_cast<SDBMDirectExpr>();
if (!lhsVar || !rhsConstant)
return {};
return SDBMDiffExpr::get(lhsVar,
@ -568,7 +572,7 @@ int64_t SDBMDirectExpr::getConstant() {
// SDBMStripeExpr
//===----------------------------------------------------------------------===//
SDBMStripeExpr SDBMStripeExpr::get(SDBMTermExpr var,
SDBMStripeExpr SDBMStripeExpr::get(SDBMDirectExpr var,
SDBMConstantExpr stripeFactor) {
assert(var && "expected SDBM variable expression");
assert(stripeFactor && "expected non-null stripe factor");
@ -581,9 +585,9 @@ SDBMStripeExpr SDBMStripeExpr::get(SDBMTermExpr var,
stripeFactor);
}
SDBMTermExpr SDBMStripeExpr::getVar() const {
SDBMDirectExpr SDBMStripeExpr::getLHS() const {
if (SDBMVaryingExpr lhs = static_cast<ImplType *>(impl)->lhs)
return lhs.cast<SDBMTermExpr>();
return lhs.cast<SDBMDirectExpr>();
return {};
}
@ -738,7 +742,7 @@ SDBMExpr stripe(SDBMExpr expr, SDBMExpr factor) {
if (constantFactor.getValue() == 1)
return expr;
return SDBMStripeExpr::get(expr.cast<SDBMTermExpr>(), constantFactor);
return SDBMStripeExpr::get(expr.cast<SDBMDirectExpr>(), constantFactor);
}
} // namespace ops_assertions

View File

@ -180,7 +180,7 @@ TEST(SDBMOperators, Stripe) {
auto expr = stripe(dim(0), 3);
auto stripeExpr = expr.dyn_cast<SDBMStripeExpr>();
ASSERT_TRUE(stripeExpr);
EXPECT_EQ(stripeExpr.getVar(), dim(0));
EXPECT_EQ(stripeExpr.getLHS(), dim(0));
EXPECT_EQ(stripeExpr.getStripeFactor().getValue(), 3);
}
@ -286,7 +286,7 @@ TEST(SDBMExpr, Stripe) {
// We can create stripe expressions and query them.
auto expr = SDBMStripeExpr::get(var, cst2);
EXPECT_EQ(expr.getVar(), var);
EXPECT_EQ(expr.getLHS(), var);
EXPECT_EQ(expr.getStripeFactor(), cst2);
// Two separately created stripe expressions with the same LHS and RHS are
@ -300,6 +300,9 @@ TEST(SDBMExpr, Stripe) {
// Non-positive stripe factors are not allowed.
EXPECT_DEATH(SDBMStripeExpr::get(var, cst0), "non-positive");
// Stripes can have sums on the LHS.
SDBMStripeExpr::get(SDBMSumExpr::get(var, cst2), cst2);
// Hierarchy is okay.
auto generic = static_cast<SDBMExpr>(expr);
EXPECT_TRUE(generic.isa<SDBMStripeExpr>());
@ -395,6 +398,14 @@ TEST(SDBMExpr, AffineRoundTrip) {
ASSERT_TRUE(roundtripped.hasValue());
EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(outerStripe));
// Check that ((s0 + 2) # 5) can be round-tripped through AffineExpr, i.e.
// stripe detection supports sum expressions.
auto inner = SDBMSumExpr::get(var, cst2);
auto stripeSum = SDBMStripeExpr::get(inner, cst5);
roundtripped = SDBMExpr::tryConvertAffineExpr(stripeSum.getAsAffineExpr());
ASSERT_TRUE(roundtripped.hasValue());
EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(stripeSum));
// Check that (s0 # 2 # 5 - s0 # 2) + 2 can be converted as an example of a
// deeper expression tree.
auto sum = SDBMSumExpr::get(outerStripe, cst2);
@ -402,6 +413,15 @@ TEST(SDBMExpr, AffineRoundTrip) {
roundtripped = SDBMExpr::tryConvertAffineExpr(diff.getAsAffineExpr());
ASSERT_TRUE(roundtripped.hasValue());
EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(diff));
// Check a nested stripe-sum combination.
auto cst7 = SDBMConstantExpr::get(dialect(), 7);
auto nestedStripe =
SDBMStripeExpr::get(SDBMSumExpr::get(stripeSum, cst2), cst7);
diff = SDBMDiffExpr::get(nestedStripe, stripe);
roundtripped = SDBMExpr::tryConvertAffineExpr(diff.getAsAffineExpr());
ASSERT_TRUE(roundtripped.hasValue());
EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(diff));
}
TEST(SDBMExpr, MatchStripeMulPattern) {