Simplify SDBM expressions more aggressively in operators and conversions

Extend SDBM simplification patterns to support more cases where the addition of
two expressions each involving one or two variables would result in a sum
expression that only contains one variable and thus remains in the SDBM domain.
This is made possible by the new canonical structure of SDBM where the constant
term appears once.  This simplification will be necessary to support
round-tripping of stripe expressions containing constant terms on the LHS
through affine expressions.

PiperOrigin-RevId: 269757732
This commit is contained in:
Alex Zinenko 2019-09-18 02:08:19 -07:00 committed by A. Unique TensorFlower
parent b58d9aee11
commit a15e0ce1ba
3 changed files with 74 additions and 32 deletions

View File

@ -47,6 +47,7 @@ class SDBMConstantExpr;
class SDBMDialect;
class SDBMDimExpr;
class SDBMSymbolExpr;
class SDBMTermExpr;
/// Striped Difference-Bounded Matrix (SDBM) expression is a base left-hand side
/// expression for the SDBM framework. SDBM expressions are a subset of affine
@ -206,6 +207,13 @@ class SDBMDirectExpr : public SDBMVaryingExpr {
public:
using SDBMVaryingExpr::SDBMVaryingExpr;
/// If this is a sum expression, return its variable part, otherwise return
/// self.
SDBMTermExpr getTerm();
/// If this is a sum expression, return its constant part, otherwise return 0.
int64_t getConstant();
static bool isClassFor(const SDBMExpr &expr) {
return expr.getKind() == SDBMExprKind::DimId ||
expr.getKind() == SDBMExprKind::SymbolId ||

View File

@ -354,32 +354,16 @@ static SDBMExpr addConstant(SDBMVaryingExpr lhs, int64_t constant) {
// Build a difference expression given a direct expression and a negation
// expression.
static SDBMExpr buildDiffExpr(SDBMDirectExpr lhs, SDBMNegExpr rhs) {
SDBMTermExpr lhsTerm, rhsTerm;
int lhsConstant = 0;
int64_t rhsConstant = 0;
if (auto lhsSum = lhs.dyn_cast<SDBMSumExpr>()) {
lhsConstant = lhsSum.getRHS().getValue();
lhsTerm = lhsSum.getLHS();
} else {
lhsTerm = lhs.cast<SDBMTermExpr>();
}
if (auto rhsNegatedSum = rhs.getVar().dyn_cast<SDBMSumExpr>()) {
rhsTerm = rhsNegatedSum.getLHS();
rhsConstant = rhsNegatedSum.getRHS().getValue();
} else {
rhsTerm = rhs.getVar().cast<SDBMTermExpr>();
}
// Fold (x + C) - (x + D) = C - D.
if (lhsTerm == rhsTerm)
return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant - rhsConstant);
if (lhs.getTerm() == rhs.getVar().getTerm())
return SDBMConstantExpr::get(
lhs.getDialect(), lhs.getConstant() - rhs.getVar().getConstant());
return SDBMDiffExpr::get(
addConstantAndSink<SDBMDirectExpr>(lhs, -rhsConstant, /*negated=*/false,
addConstantAndSink<SDBMDirectExpr>(lhs, -rhs.getVar().getConstant(),
/*negated=*/false,
[](SDBMDirectExpr e) { return e; }),
rhsTerm);
rhs.getVar().getTerm());
}
// Try folding an expression (lhs + rhs) where at least one of the operands
@ -400,18 +384,38 @@ static SDBMExpr foldSumDiff(SDBMExpr lhs, SDBMExpr rhs) {
// If a subexpression appears in a diff expression on the LHS(RHS) of a
// sum expression where it also appears on the RHS(LHS) with the opposite
// sign, we can simplify it away and obtain the SDBM form.
// x - (x - C) = -(x - C) + x = C
// (x - C) - x = -x + (x - C) = -C
auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>();
auto rhsDiff = rhs.dyn_cast<SDBMDiffExpr>();
if (lhsNeg && rhsDiff && lhsNeg.getVar() == rhsDiff.getLHS())
return SDBMNegExpr::get(rhsDiff.getRHS());
if (lhsDirect && rhsDiff && lhsDirect == rhsDiff.getRHS())
return rhsDiff.getLHS();
if (lhsDiff && rhsNeg && lhsDiff.getLHS() == rhsNeg.getVar())
return SDBMNegExpr::get(lhsDiff.getRHS());
if (rhsDirect && lhsDiff && rhsDirect == lhsDiff.getRHS())
return lhsDiff.getLHS();
// -(x + A) + ((x + B) - y) = -(y + (A - B))
if (lhsNeg && rhsDiff &&
lhsNeg.getVar().getTerm() == rhsDiff.getLHS().getTerm()) {
int64_t constant =
lhsNeg.getVar().getConstant() - rhsDiff.getLHS().getConstant();
// RHS of the diff is a term expression, its sum with a constant is a direct
// expression.
return SDBMNegExpr::get(
addConstant(rhsDiff.getRHS(), constant).cast<SDBMDirectExpr>());
}
// (x + A) + ((y + B) - x) = (y + B) + A.
if (lhsDirect && rhsDiff && lhsDirect.getTerm() == rhsDiff.getRHS())
return addConstant(rhsDiff.getLHS(), lhsDirect.getConstant());
// ((x + A) - y) + (-(x + B)) = -(y + (B - A)).
if (lhsDiff && rhsNeg &&
lhsDiff.getLHS().getTerm() == rhsNeg.getVar().getTerm()) {
int64_t constant =
rhsNeg.getVar().getConstant() - lhsDiff.getLHS().getConstant();
// RHS of the diff is a term expression, its sum with a constant is a direct
// expression.
return SDBMNegExpr::get(
addConstant(lhsDiff.getRHS(), constant).cast<SDBMDirectExpr>());
}
// ((x + A) - y) + (y + B) = (x + A) + B.
if (rhsDirect && lhsDiff && rhsDirect.getTerm() == lhsDiff.getRHS())
return addConstant(lhsDiff.getLHS(), rhsDirect.getConstant());
return {};
}
@ -544,6 +548,22 @@ SDBMTermExpr SDBMDiffExpr::getRHS() const {
return static_cast<ImplType *>(impl)->rhs;
}
//===----------------------------------------------------------------------===//
// SDBMDirectExpr
//===----------------------------------------------------------------------===//
SDBMTermExpr SDBMDirectExpr::getTerm() {
if (auto sum = dyn_cast<SDBMSumExpr>())
return sum.getLHS();
return cast<SDBMTermExpr>();
}
int64_t SDBMDirectExpr::getConstant() {
if (auto sum = dyn_cast<SDBMSumExpr>())
return sum.getRHS().getValue();
return 0;
}
//===----------------------------------------------------------------------===//
// SDBMStripeExpr
//===----------------------------------------------------------------------===//

View File

@ -100,6 +100,20 @@ TEST(SDBMOperators, AddFolding) {
EXPECT_EQ(diffOfDiffs.getValue(), 0);
}
TEST(SDBMOperators, AddNegativeTerms) {
const int64_t A = 7;
const int64_t B = -5;
auto x = SDBMDimExpr::get(dialect(), 0);
auto y = SDBMDimExpr::get(dialect(), 1);
// Check the simplification patterns in addition where one of the variables is
// cancelled out and the result remains an SDBM.
EXPECT_EQ(-(x + A) + ((x + B) - y), -(y + (A - B)));
EXPECT_EQ((x + A) + ((y + B) - x), (y + B) + A);
EXPECT_EQ(((x + A) - y) + (-(x + B)), -(y + (B - A)));
EXPECT_EQ(((x + A) - y) + (y + B), (x + A) + B);
}
TEST(SDBMOperators, Diff) {
auto expr = dim(0) - dim(1);
auto diffExpr = expr.dyn_cast<SDBMDiffExpr>();