diff --git a/mlir/include/mlir/Dialect/SDBM/SDBMExpr.h b/mlir/include/mlir/Dialect/SDBM/SDBMExpr.h index f643d92b0977..2d0fd0cf4702 100644 --- a/mlir/include/mlir/Dialect/SDBM/SDBMExpr.h +++ b/mlir/include/mlir/Dialect/SDBM/SDBMExpr.h @@ -68,6 +68,28 @@ class SDBMSymbolExpr; /// not combine in more cases than they do. This choice may be reconsidered in /// the future. /// +/// SDBM expressions are grouped into the following structure +/// - expression +/// - varying +/// - direct +/// - sum <- (term, constant) +/// - term +/// - symbol +/// - dimension +/// - stripe <- (term, constant) +/// - negation <- (direct) +/// - difference <- (direct, term) +/// - constant +/// The notation <- (...) denotes the types of subexpressions a compound +/// expression can combine. The tree of subexpressions essentially imposes the +/// following canonicalization rules: +/// - constants are always folded; +/// - constants can only appear on the RHS of an expression; +/// - double negation must be elided; +/// - an additive constant term is only allowed in a sum expression, and +/// should be sunk into the nearest such expression in the tree; +/// - zero constant expression can only appear at the top level. +/// /// `SDBMExpr` and derived classes are thin wrappers around a pointer owned by /// an MLIRContext, and should be used by-value. They are uniqued in the /// MLIRContext and immortal. @@ -208,39 +230,42 @@ public: } }; -/// SDBM sum expression. LHS is a varying expression and RHS is always a -/// constant expression. +/// SDBM sum expression. LHS is a term expression and RHS is a constant. class SDBMSumExpr : public SDBMDirectExpr { public: using ImplType = detail::SDBMBinaryExprStorage; using SDBMDirectExpr::SDBMDirectExpr; /// Obtain or create a sum expression unique'ed in the given context. - static SDBMSumExpr get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs); + static SDBMSumExpr get(SDBMTermExpr lhs, SDBMConstantExpr rhs); static bool isClassFor(const SDBMExpr &expr) { SDBMExprKind kind = expr.getKind(); return kind == SDBMExprKind::Add; } - SDBMVaryingExpr getLHS() const; + SDBMTermExpr getLHS() const; SDBMConstantExpr getRHS() const; }; -/// SDBM difference expression. Both LHS and RHS are SDBM term expressions. +/// SDBM difference expression. LHS is a direct expression, i.e. it may be a +/// sum of a term and a constant. RHS is a term expression. Thus the +/// expression (t1 - t2 + C) with term expressions t1,t2 is represented as +/// diff(sum(t1, C), t2) +/// and it is possible to extract the constant factor without negating it. class SDBMDiffExpr : public SDBMVaryingExpr { public: using ImplType = detail::SDBMDiffExprStorage; using SDBMVaryingExpr::SDBMVaryingExpr; /// Obtain or create a difference expression unique'ed in the given context. - static SDBMDiffExpr get(SDBMTermExpr lhs, SDBMTermExpr rhs); + static SDBMDiffExpr get(SDBMDirectExpr lhs, SDBMTermExpr rhs); static bool isClassFor(const SDBMExpr &expr) { return expr.getKind() == SDBMExprKind::Diff; } - SDBMTermExpr getLHS() const; + SDBMDirectExpr getLHS() const; SDBMTermExpr getRHS() const; }; @@ -319,13 +344,13 @@ public: using SDBMVaryingExpr::SDBMVaryingExpr; /// Obtain or create a negation expression unique'ed in the given context. - static SDBMNegExpr get(SDBMTermExpr var); + static SDBMNegExpr get(SDBMDirectExpr var); static bool isClassFor(const SDBMExpr &expr) { return expr.getKind() == SDBMExprKind::Neg; } - SDBMTermExpr getVar() const; + SDBMDirectExpr getVar() const; }; /// A visitor class for SDBM expressions. Calls the kind-specific function @@ -490,22 +515,22 @@ template <> struct DenseMapInfo { } }; -// SDBMVaryingExpr hash just like pointers. -template <> struct DenseMapInfo { - static mlir::SDBMVaryingExpr getEmptyKey() { +// SDBMDirectExpr hash just like pointers. +template <> struct DenseMapInfo { + static mlir::SDBMDirectExpr getEmptyKey() { auto *pointer = llvm::DenseMapInfo::getEmptyKey(); - return mlir::SDBMVaryingExpr( + return mlir::SDBMDirectExpr( static_cast(pointer)); } - static mlir::SDBMVaryingExpr getTombstoneKey() { + static mlir::SDBMDirectExpr getTombstoneKey() { auto *pointer = llvm::DenseMapInfo::getTombstoneKey(); - return mlir::SDBMVaryingExpr( + return mlir::SDBMDirectExpr( static_cast(pointer)); } - static unsigned getHashValue(mlir::SDBMVaryingExpr expr) { + static unsigned getHashValue(mlir::SDBMDirectExpr expr) { return expr.hash_value(); } - static bool isEqual(mlir::SDBMVaryingExpr lhs, mlir::SDBMVaryingExpr rhs) { + static bool isEqual(mlir::SDBMDirectExpr lhs, mlir::SDBMDirectExpr rhs) { return lhs == rhs; } }; diff --git a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp index f1c02a36312c..96b6491776e9 100644 --- a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp +++ b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp @@ -161,14 +161,14 @@ void SDBMExpr::print(raw_ostream &os) const { Printer(raw_ostream &ostream) : prn(ostream) {} void visitSum(SDBMSumExpr expr) { - visitVarying(expr.getLHS()); + visit(expr.getLHS()); prn << " + "; - visitConstant(expr.getRHS()); + visit(expr.getRHS()); } void visitDiff(SDBMDiffExpr expr) { - visitTerm(expr.getLHS()); + visit(expr.getLHS()); prn << " - "; - visitTerm(expr.getRHS()); + visit(expr.getRHS()); } void visitDim(SDBMDimExpr expr) { prn << 'd' << expr.getPosition(); } void visitSymbol(SDBMSymbolExpr expr) { prn << 's' << expr.getPosition(); } @@ -178,8 +178,13 @@ void SDBMExpr::print(raw_ostream &os) const { visitConstant(expr.getStripeFactor()); } void visitNeg(SDBMNegExpr expr) { + bool isSum = expr.getVar().isa(); prn << '-'; - visitTerm(expr.getVar()); + if (isSum) + prn << '('; + visit(expr.getVar()); + if (isSum) + prn << ')'; } void visitConstant(SDBMConstantExpr expr) { prn << expr.getValue(); } @@ -199,7 +204,7 @@ namespace { struct SDBMNegator : public SDBMVisitor { // Any term expression is wrapped into a negation expression. // -(x) = -x - SDBMExpr visitTerm(SDBMTermExpr expr) { return SDBMNegExpr::get(expr); } + SDBMExpr visitDirect(SDBMDirectExpr expr) { return SDBMNegExpr::get(expr); } // A negation expression is unwrapped. // -(-x) = x SDBMExpr visitNeg(SDBMNegExpr expr) { return expr.getVar(); } @@ -207,15 +212,20 @@ struct SDBMNegator : public SDBMVisitor { SDBMExpr visitConstant(SDBMConstantExpr expr) { return SDBMConstantExpr::get(expr.getDialect(), -expr.getValue()); } - // Both terms of the sum are negated recursively. - SDBMExpr visitSum(SDBMSumExpr expr) { - return SDBMSumExpr::get(visit(expr.getLHS()).cast(), - visit(expr.getRHS()).cast()); - } - // Terms of a difference are interchanged. - // -(x - y) = y - x + + // Terms of a difference are interchanged. Since only the LHS of a diff + // expression is allowed to be a sum with a constant, we need to recreate the + // sum with the negated value: + // -((x + C) - y) = (y - C) - x. SDBMExpr visitDiff(SDBMDiffExpr expr) { - return SDBMDiffExpr::get(expr.getRHS(), expr.getLHS()); + // If the LHS is just a term, we can do straightforward interchange. + if (auto term = expr.getLHS().dyn_cast()) + return SDBMDiffExpr::get(expr.getRHS(), term); + + auto sum = expr.getLHS().cast(); + auto cst = visitConstant(sum.getRHS()).cast(); + return SDBMDiffExpr::get(SDBMSumExpr::get(expr.getRHS(), cst), + sum.getLHS()); } }; } // namespace @@ -226,7 +236,7 @@ SDBMExpr SDBMExpr::operator-() { return SDBMNegator().visit(*this); } // SDBMSumExpr //===----------------------------------------------------------------------===// -SDBMSumExpr SDBMSumExpr::get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs) { +SDBMSumExpr SDBMSumExpr::get(SDBMTermExpr lhs, SDBMConstantExpr rhs) { assert(lhs && "expected SDBM variable expression"); assert(rhs && "expected SDBM constant"); @@ -242,8 +252,8 @@ SDBMSumExpr SDBMSumExpr::get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs) { /*initFn=*/{}, static_cast(SDBMExprKind::Add), lhs, rhs); } -SDBMVaryingExpr SDBMSumExpr::getLHS() const { - return static_cast(impl)->lhs; +SDBMTermExpr SDBMSumExpr::getLHS() const { + return static_cast(impl)->lhs.cast(); } SDBMConstantExpr SDBMSumExpr::getRHS() const { @@ -289,24 +299,126 @@ AffineExpr SDBMExpr::getAsAffineExpr() const { return converter.visit(*this); } +// Given a direct expression `expr`, add the given constant to it and pass the +// resulting expression to `builder` before returning its result. If the +// expression is already a sum expression, update its constant and extract the +// LHS if the constant becomes zero. Otherwise, construct a sum expression. +template +Result addConstantAndSink(SDBMDirectExpr expr, int64_t constant, bool negated, + llvm::function_ref builder) { + SDBMDialect *dialect = expr.getDialect(); + if (auto sumExpr = expr.dyn_cast()) { + if (negated) + constant = sumExpr.getRHS().getValue() - constant; + else + constant += sumExpr.getRHS().getValue(); + + if (constant != 0) { + auto sum = SDBMSumExpr::get(sumExpr.getLHS(), + SDBMConstantExpr::get(dialect, constant)); + return builder(sum); + } else { + return builder(sumExpr.getLHS()); + } + } + if (constant != 0) + return builder(SDBMSumExpr::get( + expr.cast(), + SDBMConstantExpr::get(dialect, negated ? -constant : constant))); + return expr; +} + +// Construct an expression lhs + constant while maintaining the canonical form +// of the SDBM expressions, in particular sink the constant expression to the +// nearest sum expression in the left subtree of the expresison tree. +static SDBMExpr addConstant(SDBMVaryingExpr lhs, int64_t constant) { + if (auto lhsDiff = lhs.dyn_cast()) + return addConstantAndSink( + lhsDiff.getLHS(), constant, /*negated=*/false, + [lhsDiff](SDBMDirectExpr e) { + return SDBMDiffExpr::get(e, lhsDiff.getRHS()); + }); + if (auto lhsNeg = lhs.dyn_cast()) + return addConstantAndSink( + lhsNeg.getVar(), constant, /*negated=*/true, + [](SDBMDirectExpr e) { return SDBMNegExpr::get(e); }); + if (auto lhsSum = lhs.dyn_cast()) + return addConstantAndSink(lhsSum, constant, /*negated=*/false, + [](SDBMDirectExpr e) { return e; }); + if (constant != 0) + return SDBMSumExpr::get(lhs.cast(), + SDBMConstantExpr::get(lhs.getDialect(), constant)); + return lhs; +} + +// Build a difference expression given a direct expression and a negation +// expression. +static SDBMExpr buildDiffExpr(SDBMDirectExpr lhs, SDBMNegExpr rhs) { + SDBMTermExpr lhsTerm, rhsTerm; + int lhsConstant = 0; + int64_t rhsConstant = 0; + + if (auto lhsSum = lhs.dyn_cast()) { + lhsConstant = lhsSum.getRHS().getValue(); + lhsTerm = lhsSum.getLHS(); + } else { + lhsTerm = lhs.cast(); + } + + if (auto rhsNegatedSum = rhs.getVar().dyn_cast()) { + rhsTerm = rhsNegatedSum.getLHS(); + rhsConstant = rhsNegatedSum.getRHS().getValue(); + } else { + rhsTerm = rhs.getVar().cast(); + } + + // Fold (x + C) - (x + D) = C - D. + if (lhsTerm == rhsTerm) + return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant - rhsConstant); + + return SDBMDiffExpr::get( + addConstantAndSink(lhs, -rhsConstant, /*negated=*/false, + [](SDBMDirectExpr e) { return e; }), + rhsTerm); +} + +// Try folding an expression (lhs + rhs) where at least one of the operands +// contains a negated variable, i.e. is a negation or a difference expression. +static SDBMExpr foldSumDiff(SDBMExpr lhs, SDBMExpr rhs) { + // If exactly one of LHS, RHS is a negation expression, we can construct + // a difference expression, which is a special kind in SDBM. + auto lhsDirect = lhs.dyn_cast(); + auto rhsDirect = rhs.dyn_cast(); + auto lhsNeg = lhs.dyn_cast(); + auto rhsNeg = rhs.dyn_cast(); + + if (lhsDirect && rhsNeg) + return buildDiffExpr(lhsDirect, rhsNeg); + if (lhsNeg && rhsDirect) + return buildDiffExpr(rhsDirect, lhsNeg); + + // If a subexpression appears in a diff expression on the LHS(RHS) of a + // sum expression where it also appears on the RHS(LHS) with the opposite + // sign, we can simplify it away and obtain the SDBM form. + // x - (x - C) = -(x - C) + x = C + // (x - C) - x = -x + (x - C) = -C + auto lhsDiff = lhs.dyn_cast(); + auto rhsDiff = rhs.dyn_cast(); + if (lhsNeg && rhsDiff && lhsNeg.getVar() == rhsDiff.getLHS()) + return SDBMNegExpr::get(rhsDiff.getRHS()); + if (lhsDirect && rhsDiff && lhsDirect == rhsDiff.getRHS()) + return rhsDiff.getLHS(); + if (lhsDiff && rhsNeg && lhsDiff.getLHS() == rhsNeg.getVar()) + return SDBMNegExpr::get(lhsDiff.getRHS()); + if (rhsDirect && lhsDiff && rhsDirect == lhsDiff.getRHS()) + return lhsDiff.getLHS(); + + return {}; +} + Optional SDBMExpr::tryConvertAffineExpr(AffineExpr affine) { struct Converter : public AffineExprVisitor { SDBMExpr visitAddExpr(AffineBinaryOpExpr expr) { - // Attempt to recover a stripe expression. Because AffineExprs don't have - // a first-class difference kind, we check for both x + -1 * (x mod C) and - // -1 * (x mod C) + x cases. - AffineExprMatcher x, C, m; - AffineExprMatcher pattern1 = ((x % C) * m) + x; - AffineExprMatcher pattern2 = x + ((x % C) * m); - if ((pattern1.match(expr) && m.getMatchedConstantValue() == -1) || - (pattern2.match(expr) && m.getMatchedConstantValue() == -1)) { - if (auto convertedLHS = visit(x.matched())) { - // TODO(ntv): return convertedLHS.stripe(C); - return SDBMStripeExpr::get( - convertedLHS.cast(), - visit(C.matched()).cast()); - } - } auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS()); if (!lhs || !rhs) return {}; @@ -314,29 +426,22 @@ Optional SDBMExpr::tryConvertAffineExpr(AffineExpr affine) { // In a "add" AffineExpr, the constant always appears on the right. If // there were two constants, they would have been folded away. assert(!lhs.isa() && "non-canonical affine expression"); - auto rhsConstant = rhs.dyn_cast(); - // SDBM accepts LHS variables and RHS constants in a sum. - auto lhsVar = lhs.dyn_cast(); - auto rhsVar = rhs.dyn_cast(); - if (rhsConstant && lhsVar) - return SDBMSumExpr::get(lhsVar, rhsConstant); + // If RHS is a constant, we can always extend the SDBM expression to + // include it by sinking the constant into the nearest sum expresion. + if (auto rhsConstant = rhs.dyn_cast()) { + assert(!lhs.isa() && "unexpected non-canonicalized sum"); - // The sum of a negated variable and a non-negated variable is a - // difference, supported as a special kind in SDBM. Because AffineExprs - // don't have first-class difference kind, check both LHS and RHS for - // negation. - auto lhsPos = lhs.dyn_cast(); - auto rhsPos = rhs.dyn_cast(); - auto lhsNeg = lhs.dyn_cast(); - auto rhsNeg = rhs.dyn_cast(); - if (lhsNeg && rhsVar) - return SDBMDiffExpr::get(rhsPos, lhsNeg.getVar()); - if (rhsNeg && lhsVar) - return SDBMDiffExpr::get(lhsPos, rhsNeg.getVar()); + int64_t constant = rhsConstant.getValue(); + auto varying = lhs.dyn_cast(); + assert(varying && "unexpected uncanonicalized sum of constants"); + return addConstant(varying, constant); + } - // Other cases don't fit into SDBM. - return {}; + // Try building a difference expression if one of the values is negated, + // or check if a difference on either hand side cancels out the outer term + // so as to remain correct within SDBM. Return null otherwise. + return foldSumDiff(lhs, rhs); } SDBMExpr visitMulExpr(AffineBinaryOpExpr expr) { @@ -367,9 +472,13 @@ Optional SDBMExpr::tryConvertAffineExpr(AffineExpr affine) { // The only supported "multiplication" expression is an SDBM is dimension // negation, that is a product of dimension and constant -1. - auto lhsVar = lhs.dyn_cast(); - if (lhsVar && rhsConstant.getValue() == -1) + if (rhsConstant.getValue() != -1) + return {}; + + if (auto lhsVar = lhs.dyn_cast()) return SDBMNegExpr::get(lhsVar); + if (auto lhsDiff = lhs.dyn_cast()) + return SDBMNegator().visitDiff(lhsDiff); // Other multiplications are not allowed in SDBM. return {}; @@ -383,7 +492,7 @@ Optional SDBMExpr::tryConvertAffineExpr(AffineExpr affine) { // 'mod' can only be converted to SDBM if its LHS is a variable // and its RHS is a constant. Then it `x mod c = x - x stripe c`. auto rhsConstant = rhs.dyn_cast(); - auto lhsVar = rhs.dyn_cast(); + auto lhsVar = lhs.dyn_cast(); if (!lhsVar || !rhsConstant) return {}; return SDBMDiffExpr::get(lhsVar, @@ -418,7 +527,7 @@ Optional SDBMExpr::tryConvertAffineExpr(AffineExpr affine) { // SDBMDiffExpr //===----------------------------------------------------------------------===// -SDBMDiffExpr SDBMDiffExpr::get(SDBMTermExpr lhs, SDBMTermExpr rhs) { +SDBMDiffExpr SDBMDiffExpr::get(SDBMDirectExpr lhs, SDBMTermExpr rhs) { assert(lhs && "expected SDBM dimension"); assert(rhs && "expected SDBM dimension"); @@ -427,7 +536,7 @@ SDBMDiffExpr SDBMDiffExpr::get(SDBMTermExpr lhs, SDBMTermExpr rhs) { /*initFn=*/{}, static_cast(SDBMExprKind::Diff), lhs, rhs); } -SDBMTermExpr SDBMDiffExpr::getLHS() const { +SDBMDirectExpr SDBMDiffExpr::getLHS() const { return static_cast(impl)->lhs; } @@ -526,31 +635,36 @@ int64_t SDBMConstantExpr::getValue() const { // SDBMNegExpr //===----------------------------------------------------------------------===// -SDBMNegExpr SDBMNegExpr::get(SDBMTermExpr var) { - assert(var && "expected non-null SDBM variable expression"); +SDBMNegExpr SDBMNegExpr::get(SDBMDirectExpr var) { + assert(var && "expected non-null SDBM direct expression"); StorageUniquer &uniquer = var.getDialect()->getUniquer(); return uniquer.get( /*initFn=*/{}, static_cast(SDBMExprKind::Neg), var); } -SDBMTermExpr SDBMNegExpr::getVar() const { - return static_cast(impl)->dim; +SDBMDirectExpr SDBMNegExpr::getVar() const { + return static_cast(impl)->expr; } namespace mlir { namespace ops_assertions { SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs) { - // If one of the operands is a negation, take a difference rather than a sum. - auto lhsNeg = lhs.dyn_cast(); - auto rhsNeg = rhs.dyn_cast(); - assert(!(lhsNeg && rhsNeg) && "a sum of negated expressions is a negation of " - "a sum of variables and not a correct SDBM"); - if (lhsNeg) - return rhs - lhsNeg.getVar(); - if (rhsNeg) - return lhs - rhsNeg.getVar(); + if (auto folded = foldSumDiff(lhs, rhs)) + return folded; + assert(!(lhs.isa() && rhs.isa()) && + "a sum of negated expressions is a negation of a sum of variables and " + "not a correct SDBM"); + + // Fold (x - y) + (y - x) = 0. + auto lhsDiff = lhs.dyn_cast(); + auto rhsDiff = rhs.dyn_cast(); + if (lhsDiff && rhsDiff) { + if (lhsDiff.getLHS() == rhsDiff.getRHS() && + lhsDiff.getRHS() == rhsDiff.getLHS()) + return SDBMConstantExpr::get(lhs.getDialect(), 0); + } // If LHS is a constant and RHS is not, swap the order to get into a supported // sum case. From now on, RHS must be a constant. @@ -562,26 +676,11 @@ SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs) { } assert(rhsConstant && "at least one operand must be a constant"); - // If LHS is another sum, first compute the sum of its variable - // part with the other argument and then add the constant part to enable - // constant folding (the variable part may, e.g., be a negation that requires - // to enter this function again). - auto lhsSum = lhs.dyn_cast(); - if (lhsSum) - return lhsSum.getLHS() + - (lhsSum.getRHS().getValue() + rhsConstant.getValue()); - - // Constant-fold if LHS is a constant. + // Constant-fold if LHS is also a constant. if (lhsConstant) return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() + rhsConstant.getValue()); - - // Fold x + 0 == x. - if (rhsConstant.getValue() == 0) - return lhs; - - return SDBMSumExpr::get(lhs.cast(), - rhs.cast()); + return addConstant(lhs.cast(), rhsConstant.getValue()); } SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs) { @@ -608,25 +707,7 @@ SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs) { if (lhsConstant) return -rhs + lhsConstant; - // Hoist constant factors outside the difference if any of sides is a sum: - // (x + A) - (y - B) == x - y + (A - B). - // If either LHS or RHS is a sum, collect the constant values separately and - // update LHS and RHS to point to the variable part of the sum. - auto lhsSum = lhs.dyn_cast(); - auto rhsSum = rhs.dyn_cast(); - int64_t value = 0; - if (lhsSum) { - value += lhsSum.getRHS().getValue(); - lhs = lhsSum.getLHS(); - } - if (rhsSum) { - value -= rhsSum.getRHS().getValue(); - rhs = rhsSum.getLHS(); - } - - // This calls into operator+ for futher simplification in case value == 0. - return SDBMDiffExpr::get(lhs.cast(), rhs.cast()) + - value; + return buildDiffExpr(lhs.cast(), (-rhs).cast()); } SDBMExpr stripe(SDBMExpr expr, SDBMExpr factor) { diff --git a/mlir/lib/Dialect/SDBM/SDBMExprDetail.h b/mlir/lib/Dialect/SDBM/SDBMExprDetail.h index b202ab5efb4f..0441200754cb 100644 --- a/mlir/lib/Dialect/SDBM/SDBMExprDetail.h +++ b/mlir/lib/Dialect/SDBM/SDBMExprDetail.h @@ -43,7 +43,7 @@ struct SDBMExprStorage : public StorageUniquer::BaseStorage { // Storage class for SDBM sum and stripe expressions. struct SDBMBinaryExprStorage : public SDBMExprStorage { - using KeyTy = std::pair; + using KeyTy = std::pair; bool operator==(const KeyTy &key) const { return std::get<0>(key) == lhs && std::get<1>(key) == rhs; @@ -58,13 +58,13 @@ struct SDBMBinaryExprStorage : public SDBMExprStorage { return result; } - SDBMVaryingExpr lhs; + SDBMDirectExpr lhs; SDBMConstantExpr rhs; }; // Storage class for SDBM difference expressions. struct SDBMDiffExprStorage : public SDBMExprStorage { - using KeyTy = std::pair; + using KeyTy = std::pair; bool operator==(const KeyTy &key) const { return std::get<0>(key) == lhs && std::get<1>(key) == rhs; @@ -79,7 +79,7 @@ struct SDBMDiffExprStorage : public SDBMExprStorage { return result; } - SDBMTermExpr lhs; + SDBMDirectExpr lhs; SDBMTermExpr rhs; }; @@ -117,19 +117,19 @@ struct SDBMTermExprStorage : public SDBMExprStorage { // Storage class for SDBM negation expressions. struct SDBMNegExprStorage : public SDBMExprStorage { - using KeyTy = SDBMTermExpr; + using KeyTy = SDBMDirectExpr; - bool operator==(const KeyTy &key) const { return key == dim; } + bool operator==(const KeyTy &key) const { return key == expr; } static SDBMNegExprStorage * construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) { auto *result = allocator.allocate(); - result->dim = key; + result->expr = key; result->dialect = key.getDialect(); return result; } - SDBMTermExpr dim; + SDBMDirectExpr expr; }; } // end namespace detail diff --git a/mlir/test/SDBM/sdbm-api-test.cpp b/mlir/test/SDBM/sdbm-api-test.cpp index b8cbaef5c351..d7a432087aec 100644 --- a/mlir/test/SDBM/sdbm-api-test.cpp +++ b/mlir/test/SDBM/sdbm-api-test.cpp @@ -161,13 +161,13 @@ TEST_FUNC(SDBM_StripeTightening) { SmallVector eqs, ineqs; sdbm.getSDBMExpressions(dialect(), ineqs, eqs); - // CHECK: s0 # 3 - d0 + -2 + // CHECK: s0 # 3 + -2 - d0 // CHECK-EMPTY: for (auto ineq : ineqs) ineq.print(llvm::outs() << '\n'); llvm::outs() << "\n"; - // CHECK-DAG: d1 - d0 + -42 + // CHECK-DAG: d1 + -42 - d0 // CHECK-DAG: d0 - s0 # 3 # 5 for (auto eq : eqs) eq.print(llvm::outs() << '\n'); diff --git a/mlir/unittests/SDBM/SDBMTest.cpp b/mlir/unittests/SDBM/SDBMTest.cpp index af44c80167f0..13941cdffd3d 100644 --- a/mlir/unittests/SDBM/SDBMTest.cpp +++ b/mlir/unittests/SDBM/SDBMTest.cpp @@ -76,6 +76,28 @@ TEST(SDBMOperators, AddFolding) { auto inverted = SDBMNegExpr::get(SDBMDimExpr::get(dialect(), 1)) + dim(0); EXPECT_EQ(inverted, expr); + + // Check that opposite values cancel each other, and that we elide the zero + // constant. + expr = dim(0) + 42; + auto onlyDim = expr - 42; + EXPECT_EQ(onlyDim, dim(0)); + + // Check that we can sink a constant under a negation. + expr = -(dim(0) + 2); + auto negatedSum = (expr + 10).dyn_cast(); + ASSERT_TRUE(negatedSum); + auto sum = negatedSum.getVar().dyn_cast(); + ASSERT_TRUE(sum); + EXPECT_EQ(sum.getRHS().getValue(), -8); + + // Sum with zero is the same as the original expression. + EXPECT_EQ(dim(0) + 0, dim(0)); + + // Sum of opposite differences is zero. + auto diffOfDiffs = + ((dim(0) - dim(1)) + (dim(1) - dim(0))).dyn_cast(); + EXPECT_EQ(diffOfDiffs.getValue(), 0); } TEST(SDBMOperators, Diff) { @@ -101,6 +123,43 @@ TEST(SDBMOperators, DiffFolding) { constantExpr = zero.dyn_cast(); ASSERT_TRUE(constantExpr); EXPECT_EQ(constantExpr.getValue(), 0); + + // Check that the constant terms in difference-of-sums are folded. + // (d0 - 3) - (d1 - 5) = (d0 + 2) - d1 + auto diffOfSums = ((dim(0) - 3) - (dim(1) - 5)).dyn_cast(); + ASSERT_TRUE(diffOfSums); + auto lhs = diffOfSums.getLHS().dyn_cast(); + ASSERT_TRUE(lhs); + EXPECT_EQ(lhs.getLHS(), dim(0)); + EXPECT_EQ(lhs.getRHS().getValue(), 2); + EXPECT_EQ(diffOfSums.getRHS(), dim(1)); + + // Check that identical dimensions with opposite signs cancel each other. + auto cstOnly = ((dim(0) + 42) - dim(0)).dyn_cast(); + ASSERT_TRUE(cstOnly); + EXPECT_EQ(cstOnly.getValue(), 42); + + // Check that identical terms in sum of diffs cancel out. + auto dimOnly = (-dim(0) + (dim(0) - dim(1))); + EXPECT_EQ(dimOnly, -dim(1)); + dimOnly = (dim(0) - dim(1)) + (-dim(0)); + EXPECT_EQ(dimOnly, -dim(1)); + dimOnly = (dim(0) - dim(1)) + dim(1); + EXPECT_EQ(dimOnly, dim(0)); + dimOnly = dim(0) + (dim(1) - dim(0)); + EXPECT_EQ(dimOnly, dim(1)); + + // Top-level zero constant is fine. + cstOnly = (-symb(1) + symb(1)).dyn_cast(); + ASSERT_TRUE(cstOnly); + EXPECT_EQ(cstOnly.getValue(), 0); +} + +TEST(SDBMOperators, Negate) { + auto sum = dim(0) + 3; + auto negated = (-sum).dyn_cast(); + ASSERT_TRUE(negated); + EXPECT_EQ(negated.getVar(), sum); } TEST(SDBMOperators, Stripe) { @@ -324,11 +383,11 @@ TEST(SDBMExpr, AffineRoundTrip) { // Check that (s0 # 2 # 5 - s0 # 2) + 2 can be converted as an example of a // deeper expression tree. - auto diff = SDBMDiffExpr::get(outerStripe, stripe); - auto sum = SDBMSumExpr::get(diff, cst2); - roundtripped = SDBMExpr::tryConvertAffineExpr(sum.getAsAffineExpr()); + auto sum = SDBMSumExpr::get(outerStripe, cst2); + auto diff = SDBMDiffExpr::get(sum, stripe); + roundtripped = SDBMExpr::tryConvertAffineExpr(diff.getAsAffineExpr()); ASSERT_TRUE(roundtripped.hasValue()); - EXPECT_EQ(roundtripped, static_cast(sum)); + EXPECT_EQ(roundtripped, static_cast(diff)); } TEST(SDBMExpr, MatchStripeMulPattern) {