Overhaul the SDBM expression kind hierarchy

Swap the allowed nesting of sum and diff expressions: now a diff expression can
contain a sum expression, but only on the left hand side.  A difference of two
expressions sum must be canonicalized by grouping their constant terms in a
single expression.  This change of sturcture became possible thanks to the
introduction of the "direct" super-kind.  It is necessary to enable support of
sum expressions on the left hand side of the stripe expression.

SDBM expressions are now grouped into the following structure
- expression
  - varying
    - direct
      - sum <- (term, constant)
      - term
        - symbol
        - dimension
        - stripe <- (term, constant)
    - negation <- (direct)
    - difference <- (direct, term)
  - constant
The notation <- (...) denotes the types of subexpressions a compound
expression can combine.

PiperOrigin-RevId: 269337222
This commit is contained in:
Alex Zinenko 2019-09-16 08:15:25 -07:00 committed by A. Unique TensorFlower
parent e94db619d9
commit cb3ecb5291
5 changed files with 302 additions and 137 deletions

View File

@ -68,6 +68,28 @@ class SDBMSymbolExpr;
/// not combine in more cases than they do. This choice may be reconsidered in /// not combine in more cases than they do. This choice may be reconsidered in
/// the future. /// the future.
/// ///
/// SDBM expressions are grouped into the following structure
/// - expression
/// - varying
/// - direct
/// - sum <- (term, constant)
/// - term
/// - symbol
/// - dimension
/// - stripe <- (term, constant)
/// - negation <- (direct)
/// - difference <- (direct, term)
/// - constant
/// The notation <- (...) denotes the types of subexpressions a compound
/// expression can combine. The tree of subexpressions essentially imposes the
/// following canonicalization rules:
/// - constants are always folded;
/// - constants can only appear on the RHS of an expression;
/// - double negation must be elided;
/// - an additive constant term is only allowed in a sum expression, and
/// should be sunk into the nearest such expression in the tree;
/// - zero constant expression can only appear at the top level.
///
/// `SDBMExpr` and derived classes are thin wrappers around a pointer owned by /// `SDBMExpr` and derived classes are thin wrappers around a pointer owned by
/// an MLIRContext, and should be used by-value. They are uniqued in the /// an MLIRContext, and should be used by-value. They are uniqued in the
/// MLIRContext and immortal. /// MLIRContext and immortal.
@ -208,39 +230,42 @@ public:
} }
}; };
/// SDBM sum expression. LHS is a varying expression and RHS is always a /// SDBM sum expression. LHS is a term expression and RHS is a constant.
/// constant expression.
class SDBMSumExpr : public SDBMDirectExpr { class SDBMSumExpr : public SDBMDirectExpr {
public: public:
using ImplType = detail::SDBMBinaryExprStorage; using ImplType = detail::SDBMBinaryExprStorage;
using SDBMDirectExpr::SDBMDirectExpr; using SDBMDirectExpr::SDBMDirectExpr;
/// Obtain or create a sum expression unique'ed in the given context. /// Obtain or create a sum expression unique'ed in the given context.
static SDBMSumExpr get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs); static SDBMSumExpr get(SDBMTermExpr lhs, SDBMConstantExpr rhs);
static bool isClassFor(const SDBMExpr &expr) { static bool isClassFor(const SDBMExpr &expr) {
SDBMExprKind kind = expr.getKind(); SDBMExprKind kind = expr.getKind();
return kind == SDBMExprKind::Add; return kind == SDBMExprKind::Add;
} }
SDBMVaryingExpr getLHS() const; SDBMTermExpr getLHS() const;
SDBMConstantExpr getRHS() const; SDBMConstantExpr getRHS() const;
}; };
/// SDBM difference expression. Both LHS and RHS are SDBM term expressions. /// SDBM difference expression. LHS is a direct expression, i.e. it may be a
/// sum of a term and a constant. RHS is a term expression. Thus the
/// expression (t1 - t2 + C) with term expressions t1,t2 is represented as
/// diff(sum(t1, C), t2)
/// and it is possible to extract the constant factor without negating it.
class SDBMDiffExpr : public SDBMVaryingExpr { class SDBMDiffExpr : public SDBMVaryingExpr {
public: public:
using ImplType = detail::SDBMDiffExprStorage; using ImplType = detail::SDBMDiffExprStorage;
using SDBMVaryingExpr::SDBMVaryingExpr; using SDBMVaryingExpr::SDBMVaryingExpr;
/// Obtain or create a difference expression unique'ed in the given context. /// Obtain or create a difference expression unique'ed in the given context.
static SDBMDiffExpr get(SDBMTermExpr lhs, SDBMTermExpr rhs); static SDBMDiffExpr get(SDBMDirectExpr lhs, SDBMTermExpr rhs);
static bool isClassFor(const SDBMExpr &expr) { static bool isClassFor(const SDBMExpr &expr) {
return expr.getKind() == SDBMExprKind::Diff; return expr.getKind() == SDBMExprKind::Diff;
} }
SDBMTermExpr getLHS() const; SDBMDirectExpr getLHS() const;
SDBMTermExpr getRHS() const; SDBMTermExpr getRHS() const;
}; };
@ -319,13 +344,13 @@ public:
using SDBMVaryingExpr::SDBMVaryingExpr; using SDBMVaryingExpr::SDBMVaryingExpr;
/// Obtain or create a negation expression unique'ed in the given context. /// Obtain or create a negation expression unique'ed in the given context.
static SDBMNegExpr get(SDBMTermExpr var); static SDBMNegExpr get(SDBMDirectExpr var);
static bool isClassFor(const SDBMExpr &expr) { static bool isClassFor(const SDBMExpr &expr) {
return expr.getKind() == SDBMExprKind::Neg; return expr.getKind() == SDBMExprKind::Neg;
} }
SDBMTermExpr getVar() const; SDBMDirectExpr getVar() const;
}; };
/// A visitor class for SDBM expressions. Calls the kind-specific function /// A visitor class for SDBM expressions. Calls the kind-specific function
@ -490,22 +515,22 @@ template <> struct DenseMapInfo<mlir::SDBMExpr> {
} }
}; };
// SDBMVaryingExpr hash just like pointers. // SDBMDirectExpr hash just like pointers.
template <> struct DenseMapInfo<mlir::SDBMVaryingExpr> { template <> struct DenseMapInfo<mlir::SDBMDirectExpr> {
static mlir::SDBMVaryingExpr getEmptyKey() { static mlir::SDBMDirectExpr getEmptyKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey(); auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return mlir::SDBMVaryingExpr( return mlir::SDBMDirectExpr(
static_cast<mlir::SDBMExpr::ImplType *>(pointer)); static_cast<mlir::SDBMExpr::ImplType *>(pointer));
} }
static mlir::SDBMVaryingExpr getTombstoneKey() { static mlir::SDBMDirectExpr getTombstoneKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey(); auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
return mlir::SDBMVaryingExpr( return mlir::SDBMDirectExpr(
static_cast<mlir::SDBMExpr::ImplType *>(pointer)); static_cast<mlir::SDBMExpr::ImplType *>(pointer));
} }
static unsigned getHashValue(mlir::SDBMVaryingExpr expr) { static unsigned getHashValue(mlir::SDBMDirectExpr expr) {
return expr.hash_value(); return expr.hash_value();
} }
static bool isEqual(mlir::SDBMVaryingExpr lhs, mlir::SDBMVaryingExpr rhs) { static bool isEqual(mlir::SDBMDirectExpr lhs, mlir::SDBMDirectExpr rhs) {
return lhs == rhs; return lhs == rhs;
} }
}; };

View File

@ -161,14 +161,14 @@ void SDBMExpr::print(raw_ostream &os) const {
Printer(raw_ostream &ostream) : prn(ostream) {} Printer(raw_ostream &ostream) : prn(ostream) {}
void visitSum(SDBMSumExpr expr) { void visitSum(SDBMSumExpr expr) {
visitVarying(expr.getLHS()); visit(expr.getLHS());
prn << " + "; prn << " + ";
visitConstant(expr.getRHS()); visit(expr.getRHS());
} }
void visitDiff(SDBMDiffExpr expr) { void visitDiff(SDBMDiffExpr expr) {
visitTerm(expr.getLHS()); visit(expr.getLHS());
prn << " - "; prn << " - ";
visitTerm(expr.getRHS()); visit(expr.getRHS());
} }
void visitDim(SDBMDimExpr expr) { prn << 'd' << expr.getPosition(); } void visitDim(SDBMDimExpr expr) { prn << 'd' << expr.getPosition(); }
void visitSymbol(SDBMSymbolExpr expr) { prn << 's' << expr.getPosition(); } void visitSymbol(SDBMSymbolExpr expr) { prn << 's' << expr.getPosition(); }
@ -178,8 +178,13 @@ void SDBMExpr::print(raw_ostream &os) const {
visitConstant(expr.getStripeFactor()); visitConstant(expr.getStripeFactor());
} }
void visitNeg(SDBMNegExpr expr) { void visitNeg(SDBMNegExpr expr) {
bool isSum = expr.getVar().isa<SDBMSumExpr>();
prn << '-'; prn << '-';
visitTerm(expr.getVar()); if (isSum)
prn << '(';
visit(expr.getVar());
if (isSum)
prn << ')';
} }
void visitConstant(SDBMConstantExpr expr) { prn << expr.getValue(); } void visitConstant(SDBMConstantExpr expr) { prn << expr.getValue(); }
@ -199,7 +204,7 @@ namespace {
struct SDBMNegator : public SDBMVisitor<SDBMNegator, SDBMExpr> { struct SDBMNegator : public SDBMVisitor<SDBMNegator, SDBMExpr> {
// Any term expression is wrapped into a negation expression. // Any term expression is wrapped into a negation expression.
// -(x) = -x // -(x) = -x
SDBMExpr visitTerm(SDBMTermExpr expr) { return SDBMNegExpr::get(expr); } SDBMExpr visitDirect(SDBMDirectExpr expr) { return SDBMNegExpr::get(expr); }
// A negation expression is unwrapped. // A negation expression is unwrapped.
// -(-x) = x // -(-x) = x
SDBMExpr visitNeg(SDBMNegExpr expr) { return expr.getVar(); } SDBMExpr visitNeg(SDBMNegExpr expr) { return expr.getVar(); }
@ -207,15 +212,20 @@ struct SDBMNegator : public SDBMVisitor<SDBMNegator, SDBMExpr> {
SDBMExpr visitConstant(SDBMConstantExpr expr) { SDBMExpr visitConstant(SDBMConstantExpr expr) {
return SDBMConstantExpr::get(expr.getDialect(), -expr.getValue()); return SDBMConstantExpr::get(expr.getDialect(), -expr.getValue());
} }
// Both terms of the sum are negated recursively.
SDBMExpr visitSum(SDBMSumExpr expr) { // Terms of a difference are interchanged. Since only the LHS of a diff
return SDBMSumExpr::get(visit(expr.getLHS()).cast<SDBMVaryingExpr>(), // expression is allowed to be a sum with a constant, we need to recreate the
visit(expr.getRHS()).cast<SDBMConstantExpr>()); // sum with the negated value:
} // -((x + C) - y) = (y - C) - x.
// Terms of a difference are interchanged.
// -(x - y) = y - x
SDBMExpr visitDiff(SDBMDiffExpr expr) { SDBMExpr visitDiff(SDBMDiffExpr expr) {
return SDBMDiffExpr::get(expr.getRHS(), expr.getLHS()); // If the LHS is just a term, we can do straightforward interchange.
if (auto term = expr.getLHS().dyn_cast<SDBMTermExpr>())
return SDBMDiffExpr::get(expr.getRHS(), term);
auto sum = expr.getLHS().cast<SDBMSumExpr>();
auto cst = visitConstant(sum.getRHS()).cast<SDBMConstantExpr>();
return SDBMDiffExpr::get(SDBMSumExpr::get(expr.getRHS(), cst),
sum.getLHS());
} }
}; };
} // namespace } // namespace
@ -226,7 +236,7 @@ SDBMExpr SDBMExpr::operator-() { return SDBMNegator().visit(*this); }
// SDBMSumExpr // SDBMSumExpr
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
SDBMSumExpr SDBMSumExpr::get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs) { SDBMSumExpr SDBMSumExpr::get(SDBMTermExpr lhs, SDBMConstantExpr rhs) {
assert(lhs && "expected SDBM variable expression"); assert(lhs && "expected SDBM variable expression");
assert(rhs && "expected SDBM constant"); assert(rhs && "expected SDBM constant");
@ -242,8 +252,8 @@ SDBMSumExpr SDBMSumExpr::get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs) {
/*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Add), lhs, rhs); /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Add), lhs, rhs);
} }
SDBMVaryingExpr SDBMSumExpr::getLHS() const { SDBMTermExpr SDBMSumExpr::getLHS() const {
return static_cast<ImplType *>(impl)->lhs; return static_cast<ImplType *>(impl)->lhs.cast<SDBMTermExpr>();
} }
SDBMConstantExpr SDBMSumExpr::getRHS() const { SDBMConstantExpr SDBMSumExpr::getRHS() const {
@ -289,24 +299,126 @@ AffineExpr SDBMExpr::getAsAffineExpr() const {
return converter.visit(*this); return converter.visit(*this);
} }
// Given a direct expression `expr`, add the given constant to it and pass the
// resulting expression to `builder` before returning its result. If the
// expression is already a sum expression, update its constant and extract the
// LHS if the constant becomes zero. Otherwise, construct a sum expression.
template <typename Result>
Result addConstantAndSink(SDBMDirectExpr expr, int64_t constant, bool negated,
llvm::function_ref<Result(SDBMDirectExpr)> builder) {
SDBMDialect *dialect = expr.getDialect();
if (auto sumExpr = expr.dyn_cast<SDBMSumExpr>()) {
if (negated)
constant = sumExpr.getRHS().getValue() - constant;
else
constant += sumExpr.getRHS().getValue();
if (constant != 0) {
auto sum = SDBMSumExpr::get(sumExpr.getLHS(),
SDBMConstantExpr::get(dialect, constant));
return builder(sum);
} else {
return builder(sumExpr.getLHS());
}
}
if (constant != 0)
return builder(SDBMSumExpr::get(
expr.cast<SDBMTermExpr>(),
SDBMConstantExpr::get(dialect, negated ? -constant : constant)));
return expr;
}
// Construct an expression lhs + constant while maintaining the canonical form
// of the SDBM expressions, in particular sink the constant expression to the
// nearest sum expression in the left subtree of the expresison tree.
static SDBMExpr addConstant(SDBMVaryingExpr lhs, int64_t constant) {
if (auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>())
return addConstantAndSink<SDBMExpr>(
lhsDiff.getLHS(), constant, /*negated=*/false,
[lhsDiff](SDBMDirectExpr e) {
return SDBMDiffExpr::get(e, lhsDiff.getRHS());
});
if (auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>())
return addConstantAndSink<SDBMExpr>(
lhsNeg.getVar(), constant, /*negated=*/true,
[](SDBMDirectExpr e) { return SDBMNegExpr::get(e); });
if (auto lhsSum = lhs.dyn_cast<SDBMSumExpr>())
return addConstantAndSink<SDBMExpr>(lhsSum, constant, /*negated=*/false,
[](SDBMDirectExpr e) { return e; });
if (constant != 0)
return SDBMSumExpr::get(lhs.cast<SDBMTermExpr>(),
SDBMConstantExpr::get(lhs.getDialect(), constant));
return lhs;
}
// Build a difference expression given a direct expression and a negation
// expression.
static SDBMExpr buildDiffExpr(SDBMDirectExpr lhs, SDBMNegExpr rhs) {
SDBMTermExpr lhsTerm, rhsTerm;
int lhsConstant = 0;
int64_t rhsConstant = 0;
if (auto lhsSum = lhs.dyn_cast<SDBMSumExpr>()) {
lhsConstant = lhsSum.getRHS().getValue();
lhsTerm = lhsSum.getLHS();
} else {
lhsTerm = lhs.cast<SDBMTermExpr>();
}
if (auto rhsNegatedSum = rhs.getVar().dyn_cast<SDBMSumExpr>()) {
rhsTerm = rhsNegatedSum.getLHS();
rhsConstant = rhsNegatedSum.getRHS().getValue();
} else {
rhsTerm = rhs.getVar().cast<SDBMTermExpr>();
}
// Fold (x + C) - (x + D) = C - D.
if (lhsTerm == rhsTerm)
return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant - rhsConstant);
return SDBMDiffExpr::get(
addConstantAndSink<SDBMDirectExpr>(lhs, -rhsConstant, /*negated=*/false,
[](SDBMDirectExpr e) { return e; }),
rhsTerm);
}
// Try folding an expression (lhs + rhs) where at least one of the operands
// contains a negated variable, i.e. is a negation or a difference expression.
static SDBMExpr foldSumDiff(SDBMExpr lhs, SDBMExpr rhs) {
// If exactly one of LHS, RHS is a negation expression, we can construct
// a difference expression, which is a special kind in SDBM.
auto lhsDirect = lhs.dyn_cast<SDBMDirectExpr>();
auto rhsDirect = rhs.dyn_cast<SDBMDirectExpr>();
auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>();
auto rhsNeg = rhs.dyn_cast<SDBMNegExpr>();
if (lhsDirect && rhsNeg)
return buildDiffExpr(lhsDirect, rhsNeg);
if (lhsNeg && rhsDirect)
return buildDiffExpr(rhsDirect, lhsNeg);
// If a subexpression appears in a diff expression on the LHS(RHS) of a
// sum expression where it also appears on the RHS(LHS) with the opposite
// sign, we can simplify it away and obtain the SDBM form.
// x - (x - C) = -(x - C) + x = C
// (x - C) - x = -x + (x - C) = -C
auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>();
auto rhsDiff = rhs.dyn_cast<SDBMDiffExpr>();
if (lhsNeg && rhsDiff && lhsNeg.getVar() == rhsDiff.getLHS())
return SDBMNegExpr::get(rhsDiff.getRHS());
if (lhsDirect && rhsDiff && lhsDirect == rhsDiff.getRHS())
return rhsDiff.getLHS();
if (lhsDiff && rhsNeg && lhsDiff.getLHS() == rhsNeg.getVar())
return SDBMNegExpr::get(lhsDiff.getRHS());
if (rhsDirect && lhsDiff && rhsDirect == lhsDiff.getRHS())
return lhsDiff.getLHS();
return {};
}
Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) { Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
struct Converter : public AffineExprVisitor<Converter, SDBMExpr> { struct Converter : public AffineExprVisitor<Converter, SDBMExpr> {
SDBMExpr visitAddExpr(AffineBinaryOpExpr expr) { SDBMExpr visitAddExpr(AffineBinaryOpExpr expr) {
// Attempt to recover a stripe expression. Because AffineExprs don't have
// a first-class difference kind, we check for both x + -1 * (x mod C) and
// -1 * (x mod C) + x cases.
AffineExprMatcher x, C, m;
AffineExprMatcher pattern1 = ((x % C) * m) + x;
AffineExprMatcher pattern2 = x + ((x % C) * m);
if ((pattern1.match(expr) && m.getMatchedConstantValue() == -1) ||
(pattern2.match(expr) && m.getMatchedConstantValue() == -1)) {
if (auto convertedLHS = visit(x.matched())) {
// TODO(ntv): return convertedLHS.stripe(C);
return SDBMStripeExpr::get(
convertedLHS.cast<SDBMTermExpr>(),
visit(C.matched()).cast<SDBMConstantExpr>());
}
}
auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS()); auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
if (!lhs || !rhs) if (!lhs || !rhs)
return {}; return {};
@ -314,29 +426,22 @@ Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
// In a "add" AffineExpr, the constant always appears on the right. If // In a "add" AffineExpr, the constant always appears on the right. If
// there were two constants, they would have been folded away. // there were two constants, they would have been folded away.
assert(!lhs.isa<SDBMConstantExpr>() && "non-canonical affine expression"); assert(!lhs.isa<SDBMConstantExpr>() && "non-canonical affine expression");
auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
// SDBM accepts LHS variables and RHS constants in a sum. // If RHS is a constant, we can always extend the SDBM expression to
auto lhsVar = lhs.dyn_cast<SDBMVaryingExpr>(); // include it by sinking the constant into the nearest sum expresion.
auto rhsVar = rhs.dyn_cast<SDBMVaryingExpr>(); if (auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>()) {
if (rhsConstant && lhsVar) assert(!lhs.isa<SDBMSumExpr>() && "unexpected non-canonicalized sum");
return SDBMSumExpr::get(lhsVar, rhsConstant);
// The sum of a negated variable and a non-negated variable is a int64_t constant = rhsConstant.getValue();
// difference, supported as a special kind in SDBM. Because AffineExprs auto varying = lhs.dyn_cast<SDBMVaryingExpr>();
// don't have first-class difference kind, check both LHS and RHS for assert(varying && "unexpected uncanonicalized sum of constants");
// negation. return addConstant(varying, constant);
auto lhsPos = lhs.dyn_cast<SDBMTermExpr>(); }
auto rhsPos = rhs.dyn_cast<SDBMTermExpr>();
auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>();
auto rhsNeg = rhs.dyn_cast<SDBMNegExpr>();
if (lhsNeg && rhsVar)
return SDBMDiffExpr::get(rhsPos, lhsNeg.getVar());
if (rhsNeg && lhsVar)
return SDBMDiffExpr::get(lhsPos, rhsNeg.getVar());
// Other cases don't fit into SDBM. // Try building a difference expression if one of the values is negated,
return {}; // or check if a difference on either hand side cancels out the outer term
// so as to remain correct within SDBM. Return null otherwise.
return foldSumDiff(lhs, rhs);
} }
SDBMExpr visitMulExpr(AffineBinaryOpExpr expr) { SDBMExpr visitMulExpr(AffineBinaryOpExpr expr) {
@ -367,9 +472,13 @@ Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
// The only supported "multiplication" expression is an SDBM is dimension // The only supported "multiplication" expression is an SDBM is dimension
// negation, that is a product of dimension and constant -1. // negation, that is a product of dimension and constant -1.
auto lhsVar = lhs.dyn_cast<SDBMTermExpr>(); if (rhsConstant.getValue() != -1)
if (lhsVar && rhsConstant.getValue() == -1) return {};
if (auto lhsVar = lhs.dyn_cast<SDBMTermExpr>())
return SDBMNegExpr::get(lhsVar); return SDBMNegExpr::get(lhsVar);
if (auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>())
return SDBMNegator().visitDiff(lhsDiff);
// Other multiplications are not allowed in SDBM. // Other multiplications are not allowed in SDBM.
return {}; return {};
@ -383,7 +492,7 @@ Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
// 'mod' can only be converted to SDBM if its LHS is a variable // 'mod' can only be converted to SDBM if its LHS is a variable
// and its RHS is a constant. Then it `x mod c = x - x stripe c`. // and its RHS is a constant. Then it `x mod c = x - x stripe c`.
auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>(); auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
auto lhsVar = rhs.dyn_cast<SDBMTermExpr>(); auto lhsVar = lhs.dyn_cast<SDBMTermExpr>();
if (!lhsVar || !rhsConstant) if (!lhsVar || !rhsConstant)
return {}; return {};
return SDBMDiffExpr::get(lhsVar, return SDBMDiffExpr::get(lhsVar,
@ -418,7 +527,7 @@ Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
// SDBMDiffExpr // SDBMDiffExpr
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
SDBMDiffExpr SDBMDiffExpr::get(SDBMTermExpr lhs, SDBMTermExpr rhs) { SDBMDiffExpr SDBMDiffExpr::get(SDBMDirectExpr lhs, SDBMTermExpr rhs) {
assert(lhs && "expected SDBM dimension"); assert(lhs && "expected SDBM dimension");
assert(rhs && "expected SDBM dimension"); assert(rhs && "expected SDBM dimension");
@ -427,7 +536,7 @@ SDBMDiffExpr SDBMDiffExpr::get(SDBMTermExpr lhs, SDBMTermExpr rhs) {
/*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Diff), lhs, rhs); /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Diff), lhs, rhs);
} }
SDBMTermExpr SDBMDiffExpr::getLHS() const { SDBMDirectExpr SDBMDiffExpr::getLHS() const {
return static_cast<ImplType *>(impl)->lhs; return static_cast<ImplType *>(impl)->lhs;
} }
@ -526,31 +635,36 @@ int64_t SDBMConstantExpr::getValue() const {
// SDBMNegExpr // SDBMNegExpr
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
SDBMNegExpr SDBMNegExpr::get(SDBMTermExpr var) { SDBMNegExpr SDBMNegExpr::get(SDBMDirectExpr var) {
assert(var && "expected non-null SDBM variable expression"); assert(var && "expected non-null SDBM direct expression");
StorageUniquer &uniquer = var.getDialect()->getUniquer(); StorageUniquer &uniquer = var.getDialect()->getUniquer();
return uniquer.get<detail::SDBMNegExprStorage>( return uniquer.get<detail::SDBMNegExprStorage>(
/*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Neg), var); /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Neg), var);
} }
SDBMTermExpr SDBMNegExpr::getVar() const { SDBMDirectExpr SDBMNegExpr::getVar() const {
return static_cast<ImplType *>(impl)->dim; return static_cast<ImplType *>(impl)->expr;
} }
namespace mlir { namespace mlir {
namespace ops_assertions { namespace ops_assertions {
SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs) { SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs) {
// If one of the operands is a negation, take a difference rather than a sum. if (auto folded = foldSumDiff(lhs, rhs))
auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>(); return folded;
auto rhsNeg = rhs.dyn_cast<SDBMNegExpr>(); assert(!(lhs.isa<SDBMNegExpr>() && rhs.isa<SDBMNegExpr>()) &&
assert(!(lhsNeg && rhsNeg) && "a sum of negated expressions is a negation of " "a sum of negated expressions is a negation of a sum of variables and "
"a sum of variables and not a correct SDBM"); "not a correct SDBM");
if (lhsNeg)
return rhs - lhsNeg.getVar(); // Fold (x - y) + (y - x) = 0.
if (rhsNeg) auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>();
return lhs - rhsNeg.getVar(); auto rhsDiff = rhs.dyn_cast<SDBMDiffExpr>();
if (lhsDiff && rhsDiff) {
if (lhsDiff.getLHS() == rhsDiff.getRHS() &&
lhsDiff.getRHS() == rhsDiff.getLHS())
return SDBMConstantExpr::get(lhs.getDialect(), 0);
}
// If LHS is a constant and RHS is not, swap the order to get into a supported // If LHS is a constant and RHS is not, swap the order to get into a supported
// sum case. From now on, RHS must be a constant. // sum case. From now on, RHS must be a constant.
@ -562,26 +676,11 @@ SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs) {
} }
assert(rhsConstant && "at least one operand must be a constant"); assert(rhsConstant && "at least one operand must be a constant");
// If LHS is another sum, first compute the sum of its variable // Constant-fold if LHS is also a constant.
// part with the other argument and then add the constant part to enable
// constant folding (the variable part may, e.g., be a negation that requires
// to enter this function again).
auto lhsSum = lhs.dyn_cast<SDBMSumExpr>();
if (lhsSum)
return lhsSum.getLHS() +
(lhsSum.getRHS().getValue() + rhsConstant.getValue());
// Constant-fold if LHS is a constant.
if (lhsConstant) if (lhsConstant)
return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() + return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() +
rhsConstant.getValue()); rhsConstant.getValue());
return addConstant(lhs.cast<SDBMVaryingExpr>(), rhsConstant.getValue());
// Fold x + 0 == x.
if (rhsConstant.getValue() == 0)
return lhs;
return SDBMSumExpr::get(lhs.cast<SDBMVaryingExpr>(),
rhs.cast<SDBMConstantExpr>());
} }
SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs) { SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs) {
@ -608,25 +707,7 @@ SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs) {
if (lhsConstant) if (lhsConstant)
return -rhs + lhsConstant; return -rhs + lhsConstant;
// Hoist constant factors outside the difference if any of sides is a sum: return buildDiffExpr(lhs.cast<SDBMDirectExpr>(), (-rhs).cast<SDBMNegExpr>());
// (x + A) - (y - B) == x - y + (A - B).
// If either LHS or RHS is a sum, collect the constant values separately and
// update LHS and RHS to point to the variable part of the sum.
auto lhsSum = lhs.dyn_cast<SDBMSumExpr>();
auto rhsSum = rhs.dyn_cast<SDBMSumExpr>();
int64_t value = 0;
if (lhsSum) {
value += lhsSum.getRHS().getValue();
lhs = lhsSum.getLHS();
}
if (rhsSum) {
value -= rhsSum.getRHS().getValue();
rhs = rhsSum.getLHS();
}
// This calls into operator+ for futher simplification in case value == 0.
return SDBMDiffExpr::get(lhs.cast<SDBMTermExpr>(), rhs.cast<SDBMTermExpr>()) +
value;
} }
SDBMExpr stripe(SDBMExpr expr, SDBMExpr factor) { SDBMExpr stripe(SDBMExpr expr, SDBMExpr factor) {

View File

@ -43,7 +43,7 @@ struct SDBMExprStorage : public StorageUniquer::BaseStorage {
// Storage class for SDBM sum and stripe expressions. // Storage class for SDBM sum and stripe expressions.
struct SDBMBinaryExprStorage : public SDBMExprStorage { struct SDBMBinaryExprStorage : public SDBMExprStorage {
using KeyTy = std::pair<SDBMVaryingExpr, SDBMConstantExpr>; using KeyTy = std::pair<SDBMDirectExpr, SDBMConstantExpr>;
bool operator==(const KeyTy &key) const { bool operator==(const KeyTy &key) const {
return std::get<0>(key) == lhs && std::get<1>(key) == rhs; return std::get<0>(key) == lhs && std::get<1>(key) == rhs;
@ -58,13 +58,13 @@ struct SDBMBinaryExprStorage : public SDBMExprStorage {
return result; return result;
} }
SDBMVaryingExpr lhs; SDBMDirectExpr lhs;
SDBMConstantExpr rhs; SDBMConstantExpr rhs;
}; };
// Storage class for SDBM difference expressions. // Storage class for SDBM difference expressions.
struct SDBMDiffExprStorage : public SDBMExprStorage { struct SDBMDiffExprStorage : public SDBMExprStorage {
using KeyTy = std::pair<SDBMTermExpr, SDBMTermExpr>; using KeyTy = std::pair<SDBMDirectExpr, SDBMTermExpr>;
bool operator==(const KeyTy &key) const { bool operator==(const KeyTy &key) const {
return std::get<0>(key) == lhs && std::get<1>(key) == rhs; return std::get<0>(key) == lhs && std::get<1>(key) == rhs;
@ -79,7 +79,7 @@ struct SDBMDiffExprStorage : public SDBMExprStorage {
return result; return result;
} }
SDBMTermExpr lhs; SDBMDirectExpr lhs;
SDBMTermExpr rhs; SDBMTermExpr rhs;
}; };
@ -117,19 +117,19 @@ struct SDBMTermExprStorage : public SDBMExprStorage {
// Storage class for SDBM negation expressions. // Storage class for SDBM negation expressions.
struct SDBMNegExprStorage : public SDBMExprStorage { struct SDBMNegExprStorage : public SDBMExprStorage {
using KeyTy = SDBMTermExpr; using KeyTy = SDBMDirectExpr;
bool operator==(const KeyTy &key) const { return key == dim; } bool operator==(const KeyTy &key) const { return key == expr; }
static SDBMNegExprStorage * static SDBMNegExprStorage *
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) { construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
auto *result = allocator.allocate<SDBMNegExprStorage>(); auto *result = allocator.allocate<SDBMNegExprStorage>();
result->dim = key; result->expr = key;
result->dialect = key.getDialect(); result->dialect = key.getDialect();
return result; return result;
} }
SDBMTermExpr dim; SDBMDirectExpr expr;
}; };
} // end namespace detail } // end namespace detail

View File

@ -161,13 +161,13 @@ TEST_FUNC(SDBM_StripeTightening) {
SmallVector<SDBMExpr, 4> eqs, ineqs; SmallVector<SDBMExpr, 4> eqs, ineqs;
sdbm.getSDBMExpressions(dialect(), ineqs, eqs); sdbm.getSDBMExpressions(dialect(), ineqs, eqs);
// CHECK: s0 # 3 - d0 + -2 // CHECK: s0 # 3 + -2 - d0
// CHECK-EMPTY: // CHECK-EMPTY:
for (auto ineq : ineqs) for (auto ineq : ineqs)
ineq.print(llvm::outs() << '\n'); ineq.print(llvm::outs() << '\n');
llvm::outs() << "\n"; llvm::outs() << "\n";
// CHECK-DAG: d1 - d0 + -42 // CHECK-DAG: d1 + -42 - d0
// CHECK-DAG: d0 - s0 # 3 # 5 // CHECK-DAG: d0 - s0 # 3 # 5
for (auto eq : eqs) for (auto eq : eqs)
eq.print(llvm::outs() << '\n'); eq.print(llvm::outs() << '\n');

View File

@ -76,6 +76,28 @@ TEST(SDBMOperators, AddFolding) {
auto inverted = SDBMNegExpr::get(SDBMDimExpr::get(dialect(), 1)) + dim(0); auto inverted = SDBMNegExpr::get(SDBMDimExpr::get(dialect(), 1)) + dim(0);
EXPECT_EQ(inverted, expr); EXPECT_EQ(inverted, expr);
// Check that opposite values cancel each other, and that we elide the zero
// constant.
expr = dim(0) + 42;
auto onlyDim = expr - 42;
EXPECT_EQ(onlyDim, dim(0));
// Check that we can sink a constant under a negation.
expr = -(dim(0) + 2);
auto negatedSum = (expr + 10).dyn_cast<SDBMNegExpr>();
ASSERT_TRUE(negatedSum);
auto sum = negatedSum.getVar().dyn_cast<SDBMSumExpr>();
ASSERT_TRUE(sum);
EXPECT_EQ(sum.getRHS().getValue(), -8);
// Sum with zero is the same as the original expression.
EXPECT_EQ(dim(0) + 0, dim(0));
// Sum of opposite differences is zero.
auto diffOfDiffs =
((dim(0) - dim(1)) + (dim(1) - dim(0))).dyn_cast<SDBMConstantExpr>();
EXPECT_EQ(diffOfDiffs.getValue(), 0);
} }
TEST(SDBMOperators, Diff) { TEST(SDBMOperators, Diff) {
@ -101,6 +123,43 @@ TEST(SDBMOperators, DiffFolding) {
constantExpr = zero.dyn_cast<SDBMConstantExpr>(); constantExpr = zero.dyn_cast<SDBMConstantExpr>();
ASSERT_TRUE(constantExpr); ASSERT_TRUE(constantExpr);
EXPECT_EQ(constantExpr.getValue(), 0); EXPECT_EQ(constantExpr.getValue(), 0);
// Check that the constant terms in difference-of-sums are folded.
// (d0 - 3) - (d1 - 5) = (d0 + 2) - d1
auto diffOfSums = ((dim(0) - 3) - (dim(1) - 5)).dyn_cast<SDBMDiffExpr>();
ASSERT_TRUE(diffOfSums);
auto lhs = diffOfSums.getLHS().dyn_cast<SDBMSumExpr>();
ASSERT_TRUE(lhs);
EXPECT_EQ(lhs.getLHS(), dim(0));
EXPECT_EQ(lhs.getRHS().getValue(), 2);
EXPECT_EQ(diffOfSums.getRHS(), dim(1));
// Check that identical dimensions with opposite signs cancel each other.
auto cstOnly = ((dim(0) + 42) - dim(0)).dyn_cast<SDBMConstantExpr>();
ASSERT_TRUE(cstOnly);
EXPECT_EQ(cstOnly.getValue(), 42);
// Check that identical terms in sum of diffs cancel out.
auto dimOnly = (-dim(0) + (dim(0) - dim(1)));
EXPECT_EQ(dimOnly, -dim(1));
dimOnly = (dim(0) - dim(1)) + (-dim(0));
EXPECT_EQ(dimOnly, -dim(1));
dimOnly = (dim(0) - dim(1)) + dim(1);
EXPECT_EQ(dimOnly, dim(0));
dimOnly = dim(0) + (dim(1) - dim(0));
EXPECT_EQ(dimOnly, dim(1));
// Top-level zero constant is fine.
cstOnly = (-symb(1) + symb(1)).dyn_cast<SDBMConstantExpr>();
ASSERT_TRUE(cstOnly);
EXPECT_EQ(cstOnly.getValue(), 0);
}
TEST(SDBMOperators, Negate) {
auto sum = dim(0) + 3;
auto negated = (-sum).dyn_cast<SDBMNegExpr>();
ASSERT_TRUE(negated);
EXPECT_EQ(negated.getVar(), sum);
} }
TEST(SDBMOperators, Stripe) { TEST(SDBMOperators, Stripe) {
@ -324,11 +383,11 @@ TEST(SDBMExpr, AffineRoundTrip) {
// Check that (s0 # 2 # 5 - s0 # 2) + 2 can be converted as an example of a // Check that (s0 # 2 # 5 - s0 # 2) + 2 can be converted as an example of a
// deeper expression tree. // deeper expression tree.
auto diff = SDBMDiffExpr::get(outerStripe, stripe); auto sum = SDBMSumExpr::get(outerStripe, cst2);
auto sum = SDBMSumExpr::get(diff, cst2); auto diff = SDBMDiffExpr::get(sum, stripe);
roundtripped = SDBMExpr::tryConvertAffineExpr(sum.getAsAffineExpr()); roundtripped = SDBMExpr::tryConvertAffineExpr(diff.getAsAffineExpr());
ASSERT_TRUE(roundtripped.hasValue()); ASSERT_TRUE(roundtripped.hasValue());
EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(sum)); EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(diff));
} }
TEST(SDBMExpr, MatchStripeMulPattern) { TEST(SDBMExpr, MatchStripeMulPattern) {