forked from OSchip/llvm-project
Overhaul the SDBM expression kind hierarchy
Swap the allowed nesting of sum and diff expressions: now a diff expression can contain a sum expression, but only on the left hand side. A difference of two expressions sum must be canonicalized by grouping their constant terms in a single expression. This change of sturcture became possible thanks to the introduction of the "direct" super-kind. It is necessary to enable support of sum expressions on the left hand side of the stripe expression. SDBM expressions are now grouped into the following structure - expression - varying - direct - sum <- (term, constant) - term - symbol - dimension - stripe <- (term, constant) - negation <- (direct) - difference <- (direct, term) - constant The notation <- (...) denotes the types of subexpressions a compound expression can combine. PiperOrigin-RevId: 269337222
This commit is contained in:
parent
e94db619d9
commit
cb3ecb5291
|
@ -68,6 +68,28 @@ class SDBMSymbolExpr;
|
|||
/// not combine in more cases than they do. This choice may be reconsidered in
|
||||
/// the future.
|
||||
///
|
||||
/// SDBM expressions are grouped into the following structure
|
||||
/// - expression
|
||||
/// - varying
|
||||
/// - direct
|
||||
/// - sum <- (term, constant)
|
||||
/// - term
|
||||
/// - symbol
|
||||
/// - dimension
|
||||
/// - stripe <- (term, constant)
|
||||
/// - negation <- (direct)
|
||||
/// - difference <- (direct, term)
|
||||
/// - constant
|
||||
/// The notation <- (...) denotes the types of subexpressions a compound
|
||||
/// expression can combine. The tree of subexpressions essentially imposes the
|
||||
/// following canonicalization rules:
|
||||
/// - constants are always folded;
|
||||
/// - constants can only appear on the RHS of an expression;
|
||||
/// - double negation must be elided;
|
||||
/// - an additive constant term is only allowed in a sum expression, and
|
||||
/// should be sunk into the nearest such expression in the tree;
|
||||
/// - zero constant expression can only appear at the top level.
|
||||
///
|
||||
/// `SDBMExpr` and derived classes are thin wrappers around a pointer owned by
|
||||
/// an MLIRContext, and should be used by-value. They are uniqued in the
|
||||
/// MLIRContext and immortal.
|
||||
|
@ -208,39 +230,42 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
/// SDBM sum expression. LHS is a varying expression and RHS is always a
|
||||
/// constant expression.
|
||||
/// SDBM sum expression. LHS is a term expression and RHS is a constant.
|
||||
class SDBMSumExpr : public SDBMDirectExpr {
|
||||
public:
|
||||
using ImplType = detail::SDBMBinaryExprStorage;
|
||||
using SDBMDirectExpr::SDBMDirectExpr;
|
||||
|
||||
/// Obtain or create a sum expression unique'ed in the given context.
|
||||
static SDBMSumExpr get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs);
|
||||
static SDBMSumExpr get(SDBMTermExpr lhs, SDBMConstantExpr rhs);
|
||||
|
||||
static bool isClassFor(const SDBMExpr &expr) {
|
||||
SDBMExprKind kind = expr.getKind();
|
||||
return kind == SDBMExprKind::Add;
|
||||
}
|
||||
|
||||
SDBMVaryingExpr getLHS() const;
|
||||
SDBMTermExpr getLHS() const;
|
||||
SDBMConstantExpr getRHS() const;
|
||||
};
|
||||
|
||||
/// SDBM difference expression. Both LHS and RHS are SDBM term expressions.
|
||||
/// SDBM difference expression. LHS is a direct expression, i.e. it may be a
|
||||
/// sum of a term and a constant. RHS is a term expression. Thus the
|
||||
/// expression (t1 - t2 + C) with term expressions t1,t2 is represented as
|
||||
/// diff(sum(t1, C), t2)
|
||||
/// and it is possible to extract the constant factor without negating it.
|
||||
class SDBMDiffExpr : public SDBMVaryingExpr {
|
||||
public:
|
||||
using ImplType = detail::SDBMDiffExprStorage;
|
||||
using SDBMVaryingExpr::SDBMVaryingExpr;
|
||||
|
||||
/// Obtain or create a difference expression unique'ed in the given context.
|
||||
static SDBMDiffExpr get(SDBMTermExpr lhs, SDBMTermExpr rhs);
|
||||
static SDBMDiffExpr get(SDBMDirectExpr lhs, SDBMTermExpr rhs);
|
||||
|
||||
static bool isClassFor(const SDBMExpr &expr) {
|
||||
return expr.getKind() == SDBMExprKind::Diff;
|
||||
}
|
||||
|
||||
SDBMTermExpr getLHS() const;
|
||||
SDBMDirectExpr getLHS() const;
|
||||
SDBMTermExpr getRHS() const;
|
||||
};
|
||||
|
||||
|
@ -319,13 +344,13 @@ public:
|
|||
using SDBMVaryingExpr::SDBMVaryingExpr;
|
||||
|
||||
/// Obtain or create a negation expression unique'ed in the given context.
|
||||
static SDBMNegExpr get(SDBMTermExpr var);
|
||||
static SDBMNegExpr get(SDBMDirectExpr var);
|
||||
|
||||
static bool isClassFor(const SDBMExpr &expr) {
|
||||
return expr.getKind() == SDBMExprKind::Neg;
|
||||
}
|
||||
|
||||
SDBMTermExpr getVar() const;
|
||||
SDBMDirectExpr getVar() const;
|
||||
};
|
||||
|
||||
/// A visitor class for SDBM expressions. Calls the kind-specific function
|
||||
|
@ -490,22 +515,22 @@ template <> struct DenseMapInfo<mlir::SDBMExpr> {
|
|||
}
|
||||
};
|
||||
|
||||
// SDBMVaryingExpr hash just like pointers.
|
||||
template <> struct DenseMapInfo<mlir::SDBMVaryingExpr> {
|
||||
static mlir::SDBMVaryingExpr getEmptyKey() {
|
||||
// SDBMDirectExpr hash just like pointers.
|
||||
template <> struct DenseMapInfo<mlir::SDBMDirectExpr> {
|
||||
static mlir::SDBMDirectExpr getEmptyKey() {
|
||||
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||
return mlir::SDBMVaryingExpr(
|
||||
return mlir::SDBMDirectExpr(
|
||||
static_cast<mlir::SDBMExpr::ImplType *>(pointer));
|
||||
}
|
||||
static mlir::SDBMVaryingExpr getTombstoneKey() {
|
||||
static mlir::SDBMDirectExpr getTombstoneKey() {
|
||||
auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
|
||||
return mlir::SDBMVaryingExpr(
|
||||
return mlir::SDBMDirectExpr(
|
||||
static_cast<mlir::SDBMExpr::ImplType *>(pointer));
|
||||
}
|
||||
static unsigned getHashValue(mlir::SDBMVaryingExpr expr) {
|
||||
static unsigned getHashValue(mlir::SDBMDirectExpr expr) {
|
||||
return expr.hash_value();
|
||||
}
|
||||
static bool isEqual(mlir::SDBMVaryingExpr lhs, mlir::SDBMVaryingExpr rhs) {
|
||||
static bool isEqual(mlir::SDBMDirectExpr lhs, mlir::SDBMDirectExpr rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
};
|
||||
|
|
|
@ -161,14 +161,14 @@ void SDBMExpr::print(raw_ostream &os) const {
|
|||
Printer(raw_ostream &ostream) : prn(ostream) {}
|
||||
|
||||
void visitSum(SDBMSumExpr expr) {
|
||||
visitVarying(expr.getLHS());
|
||||
visit(expr.getLHS());
|
||||
prn << " + ";
|
||||
visitConstant(expr.getRHS());
|
||||
visit(expr.getRHS());
|
||||
}
|
||||
void visitDiff(SDBMDiffExpr expr) {
|
||||
visitTerm(expr.getLHS());
|
||||
visit(expr.getLHS());
|
||||
prn << " - ";
|
||||
visitTerm(expr.getRHS());
|
||||
visit(expr.getRHS());
|
||||
}
|
||||
void visitDim(SDBMDimExpr expr) { prn << 'd' << expr.getPosition(); }
|
||||
void visitSymbol(SDBMSymbolExpr expr) { prn << 's' << expr.getPosition(); }
|
||||
|
@ -178,8 +178,13 @@ void SDBMExpr::print(raw_ostream &os) const {
|
|||
visitConstant(expr.getStripeFactor());
|
||||
}
|
||||
void visitNeg(SDBMNegExpr expr) {
|
||||
bool isSum = expr.getVar().isa<SDBMSumExpr>();
|
||||
prn << '-';
|
||||
visitTerm(expr.getVar());
|
||||
if (isSum)
|
||||
prn << '(';
|
||||
visit(expr.getVar());
|
||||
if (isSum)
|
||||
prn << ')';
|
||||
}
|
||||
void visitConstant(SDBMConstantExpr expr) { prn << expr.getValue(); }
|
||||
|
||||
|
@ -199,7 +204,7 @@ namespace {
|
|||
struct SDBMNegator : public SDBMVisitor<SDBMNegator, SDBMExpr> {
|
||||
// Any term expression is wrapped into a negation expression.
|
||||
// -(x) = -x
|
||||
SDBMExpr visitTerm(SDBMTermExpr expr) { return SDBMNegExpr::get(expr); }
|
||||
SDBMExpr visitDirect(SDBMDirectExpr expr) { return SDBMNegExpr::get(expr); }
|
||||
// A negation expression is unwrapped.
|
||||
// -(-x) = x
|
||||
SDBMExpr visitNeg(SDBMNegExpr expr) { return expr.getVar(); }
|
||||
|
@ -207,15 +212,20 @@ struct SDBMNegator : public SDBMVisitor<SDBMNegator, SDBMExpr> {
|
|||
SDBMExpr visitConstant(SDBMConstantExpr expr) {
|
||||
return SDBMConstantExpr::get(expr.getDialect(), -expr.getValue());
|
||||
}
|
||||
// Both terms of the sum are negated recursively.
|
||||
SDBMExpr visitSum(SDBMSumExpr expr) {
|
||||
return SDBMSumExpr::get(visit(expr.getLHS()).cast<SDBMVaryingExpr>(),
|
||||
visit(expr.getRHS()).cast<SDBMConstantExpr>());
|
||||
}
|
||||
// Terms of a difference are interchanged.
|
||||
// -(x - y) = y - x
|
||||
|
||||
// Terms of a difference are interchanged. Since only the LHS of a diff
|
||||
// expression is allowed to be a sum with a constant, we need to recreate the
|
||||
// sum with the negated value:
|
||||
// -((x + C) - y) = (y - C) - x.
|
||||
SDBMExpr visitDiff(SDBMDiffExpr expr) {
|
||||
return SDBMDiffExpr::get(expr.getRHS(), expr.getLHS());
|
||||
// If the LHS is just a term, we can do straightforward interchange.
|
||||
if (auto term = expr.getLHS().dyn_cast<SDBMTermExpr>())
|
||||
return SDBMDiffExpr::get(expr.getRHS(), term);
|
||||
|
||||
auto sum = expr.getLHS().cast<SDBMSumExpr>();
|
||||
auto cst = visitConstant(sum.getRHS()).cast<SDBMConstantExpr>();
|
||||
return SDBMDiffExpr::get(SDBMSumExpr::get(expr.getRHS(), cst),
|
||||
sum.getLHS());
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
@ -226,7 +236,7 @@ SDBMExpr SDBMExpr::operator-() { return SDBMNegator().visit(*this); }
|
|||
// SDBMSumExpr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
SDBMSumExpr SDBMSumExpr::get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs) {
|
||||
SDBMSumExpr SDBMSumExpr::get(SDBMTermExpr lhs, SDBMConstantExpr rhs) {
|
||||
assert(lhs && "expected SDBM variable expression");
|
||||
assert(rhs && "expected SDBM constant");
|
||||
|
||||
|
@ -242,8 +252,8 @@ SDBMSumExpr SDBMSumExpr::get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs) {
|
|||
/*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Add), lhs, rhs);
|
||||
}
|
||||
|
||||
SDBMVaryingExpr SDBMSumExpr::getLHS() const {
|
||||
return static_cast<ImplType *>(impl)->lhs;
|
||||
SDBMTermExpr SDBMSumExpr::getLHS() const {
|
||||
return static_cast<ImplType *>(impl)->lhs.cast<SDBMTermExpr>();
|
||||
}
|
||||
|
||||
SDBMConstantExpr SDBMSumExpr::getRHS() const {
|
||||
|
@ -289,24 +299,126 @@ AffineExpr SDBMExpr::getAsAffineExpr() const {
|
|||
return converter.visit(*this);
|
||||
}
|
||||
|
||||
// Given a direct expression `expr`, add the given constant to it and pass the
|
||||
// resulting expression to `builder` before returning its result. If the
|
||||
// expression is already a sum expression, update its constant and extract the
|
||||
// LHS if the constant becomes zero. Otherwise, construct a sum expression.
|
||||
template <typename Result>
|
||||
Result addConstantAndSink(SDBMDirectExpr expr, int64_t constant, bool negated,
|
||||
llvm::function_ref<Result(SDBMDirectExpr)> builder) {
|
||||
SDBMDialect *dialect = expr.getDialect();
|
||||
if (auto sumExpr = expr.dyn_cast<SDBMSumExpr>()) {
|
||||
if (negated)
|
||||
constant = sumExpr.getRHS().getValue() - constant;
|
||||
else
|
||||
constant += sumExpr.getRHS().getValue();
|
||||
|
||||
if (constant != 0) {
|
||||
auto sum = SDBMSumExpr::get(sumExpr.getLHS(),
|
||||
SDBMConstantExpr::get(dialect, constant));
|
||||
return builder(sum);
|
||||
} else {
|
||||
return builder(sumExpr.getLHS());
|
||||
}
|
||||
}
|
||||
if (constant != 0)
|
||||
return builder(SDBMSumExpr::get(
|
||||
expr.cast<SDBMTermExpr>(),
|
||||
SDBMConstantExpr::get(dialect, negated ? -constant : constant)));
|
||||
return expr;
|
||||
}
|
||||
|
||||
// Construct an expression lhs + constant while maintaining the canonical form
|
||||
// of the SDBM expressions, in particular sink the constant expression to the
|
||||
// nearest sum expression in the left subtree of the expresison tree.
|
||||
static SDBMExpr addConstant(SDBMVaryingExpr lhs, int64_t constant) {
|
||||
if (auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>())
|
||||
return addConstantAndSink<SDBMExpr>(
|
||||
lhsDiff.getLHS(), constant, /*negated=*/false,
|
||||
[lhsDiff](SDBMDirectExpr e) {
|
||||
return SDBMDiffExpr::get(e, lhsDiff.getRHS());
|
||||
});
|
||||
if (auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>())
|
||||
return addConstantAndSink<SDBMExpr>(
|
||||
lhsNeg.getVar(), constant, /*negated=*/true,
|
||||
[](SDBMDirectExpr e) { return SDBMNegExpr::get(e); });
|
||||
if (auto lhsSum = lhs.dyn_cast<SDBMSumExpr>())
|
||||
return addConstantAndSink<SDBMExpr>(lhsSum, constant, /*negated=*/false,
|
||||
[](SDBMDirectExpr e) { return e; });
|
||||
if (constant != 0)
|
||||
return SDBMSumExpr::get(lhs.cast<SDBMTermExpr>(),
|
||||
SDBMConstantExpr::get(lhs.getDialect(), constant));
|
||||
return lhs;
|
||||
}
|
||||
|
||||
// Build a difference expression given a direct expression and a negation
|
||||
// expression.
|
||||
static SDBMExpr buildDiffExpr(SDBMDirectExpr lhs, SDBMNegExpr rhs) {
|
||||
SDBMTermExpr lhsTerm, rhsTerm;
|
||||
int lhsConstant = 0;
|
||||
int64_t rhsConstant = 0;
|
||||
|
||||
if (auto lhsSum = lhs.dyn_cast<SDBMSumExpr>()) {
|
||||
lhsConstant = lhsSum.getRHS().getValue();
|
||||
lhsTerm = lhsSum.getLHS();
|
||||
} else {
|
||||
lhsTerm = lhs.cast<SDBMTermExpr>();
|
||||
}
|
||||
|
||||
if (auto rhsNegatedSum = rhs.getVar().dyn_cast<SDBMSumExpr>()) {
|
||||
rhsTerm = rhsNegatedSum.getLHS();
|
||||
rhsConstant = rhsNegatedSum.getRHS().getValue();
|
||||
} else {
|
||||
rhsTerm = rhs.getVar().cast<SDBMTermExpr>();
|
||||
}
|
||||
|
||||
// Fold (x + C) - (x + D) = C - D.
|
||||
if (lhsTerm == rhsTerm)
|
||||
return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant - rhsConstant);
|
||||
|
||||
return SDBMDiffExpr::get(
|
||||
addConstantAndSink<SDBMDirectExpr>(lhs, -rhsConstant, /*negated=*/false,
|
||||
[](SDBMDirectExpr e) { return e; }),
|
||||
rhsTerm);
|
||||
}
|
||||
|
||||
// Try folding an expression (lhs + rhs) where at least one of the operands
|
||||
// contains a negated variable, i.e. is a negation or a difference expression.
|
||||
static SDBMExpr foldSumDiff(SDBMExpr lhs, SDBMExpr rhs) {
|
||||
// If exactly one of LHS, RHS is a negation expression, we can construct
|
||||
// a difference expression, which is a special kind in SDBM.
|
||||
auto lhsDirect = lhs.dyn_cast<SDBMDirectExpr>();
|
||||
auto rhsDirect = rhs.dyn_cast<SDBMDirectExpr>();
|
||||
auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>();
|
||||
auto rhsNeg = rhs.dyn_cast<SDBMNegExpr>();
|
||||
|
||||
if (lhsDirect && rhsNeg)
|
||||
return buildDiffExpr(lhsDirect, rhsNeg);
|
||||
if (lhsNeg && rhsDirect)
|
||||
return buildDiffExpr(rhsDirect, lhsNeg);
|
||||
|
||||
// If a subexpression appears in a diff expression on the LHS(RHS) of a
|
||||
// sum expression where it also appears on the RHS(LHS) with the opposite
|
||||
// sign, we can simplify it away and obtain the SDBM form.
|
||||
// x - (x - C) = -(x - C) + x = C
|
||||
// (x - C) - x = -x + (x - C) = -C
|
||||
auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>();
|
||||
auto rhsDiff = rhs.dyn_cast<SDBMDiffExpr>();
|
||||
if (lhsNeg && rhsDiff && lhsNeg.getVar() == rhsDiff.getLHS())
|
||||
return SDBMNegExpr::get(rhsDiff.getRHS());
|
||||
if (lhsDirect && rhsDiff && lhsDirect == rhsDiff.getRHS())
|
||||
return rhsDiff.getLHS();
|
||||
if (lhsDiff && rhsNeg && lhsDiff.getLHS() == rhsNeg.getVar())
|
||||
return SDBMNegExpr::get(lhsDiff.getRHS());
|
||||
if (rhsDirect && lhsDiff && rhsDirect == lhsDiff.getRHS())
|
||||
return lhsDiff.getLHS();
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
|
||||
struct Converter : public AffineExprVisitor<Converter, SDBMExpr> {
|
||||
SDBMExpr visitAddExpr(AffineBinaryOpExpr expr) {
|
||||
// Attempt to recover a stripe expression. Because AffineExprs don't have
|
||||
// a first-class difference kind, we check for both x + -1 * (x mod C) and
|
||||
// -1 * (x mod C) + x cases.
|
||||
AffineExprMatcher x, C, m;
|
||||
AffineExprMatcher pattern1 = ((x % C) * m) + x;
|
||||
AffineExprMatcher pattern2 = x + ((x % C) * m);
|
||||
if ((pattern1.match(expr) && m.getMatchedConstantValue() == -1) ||
|
||||
(pattern2.match(expr) && m.getMatchedConstantValue() == -1)) {
|
||||
if (auto convertedLHS = visit(x.matched())) {
|
||||
// TODO(ntv): return convertedLHS.stripe(C);
|
||||
return SDBMStripeExpr::get(
|
||||
convertedLHS.cast<SDBMTermExpr>(),
|
||||
visit(C.matched()).cast<SDBMConstantExpr>());
|
||||
}
|
||||
}
|
||||
auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
|
||||
if (!lhs || !rhs)
|
||||
return {};
|
||||
|
@ -314,29 +426,22 @@ Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
|
|||
// In a "add" AffineExpr, the constant always appears on the right. If
|
||||
// there were two constants, they would have been folded away.
|
||||
assert(!lhs.isa<SDBMConstantExpr>() && "non-canonical affine expression");
|
||||
auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
|
||||
|
||||
// SDBM accepts LHS variables and RHS constants in a sum.
|
||||
auto lhsVar = lhs.dyn_cast<SDBMVaryingExpr>();
|
||||
auto rhsVar = rhs.dyn_cast<SDBMVaryingExpr>();
|
||||
if (rhsConstant && lhsVar)
|
||||
return SDBMSumExpr::get(lhsVar, rhsConstant);
|
||||
// If RHS is a constant, we can always extend the SDBM expression to
|
||||
// include it by sinking the constant into the nearest sum expresion.
|
||||
if (auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>()) {
|
||||
assert(!lhs.isa<SDBMSumExpr>() && "unexpected non-canonicalized sum");
|
||||
|
||||
// The sum of a negated variable and a non-negated variable is a
|
||||
// difference, supported as a special kind in SDBM. Because AffineExprs
|
||||
// don't have first-class difference kind, check both LHS and RHS for
|
||||
// negation.
|
||||
auto lhsPos = lhs.dyn_cast<SDBMTermExpr>();
|
||||
auto rhsPos = rhs.dyn_cast<SDBMTermExpr>();
|
||||
auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>();
|
||||
auto rhsNeg = rhs.dyn_cast<SDBMNegExpr>();
|
||||
if (lhsNeg && rhsVar)
|
||||
return SDBMDiffExpr::get(rhsPos, lhsNeg.getVar());
|
||||
if (rhsNeg && lhsVar)
|
||||
return SDBMDiffExpr::get(lhsPos, rhsNeg.getVar());
|
||||
int64_t constant = rhsConstant.getValue();
|
||||
auto varying = lhs.dyn_cast<SDBMVaryingExpr>();
|
||||
assert(varying && "unexpected uncanonicalized sum of constants");
|
||||
return addConstant(varying, constant);
|
||||
}
|
||||
|
||||
// Other cases don't fit into SDBM.
|
||||
return {};
|
||||
// Try building a difference expression if one of the values is negated,
|
||||
// or check if a difference on either hand side cancels out the outer term
|
||||
// so as to remain correct within SDBM. Return null otherwise.
|
||||
return foldSumDiff(lhs, rhs);
|
||||
}
|
||||
|
||||
SDBMExpr visitMulExpr(AffineBinaryOpExpr expr) {
|
||||
|
@ -367,9 +472,13 @@ Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
|
|||
|
||||
// The only supported "multiplication" expression is an SDBM is dimension
|
||||
// negation, that is a product of dimension and constant -1.
|
||||
auto lhsVar = lhs.dyn_cast<SDBMTermExpr>();
|
||||
if (lhsVar && rhsConstant.getValue() == -1)
|
||||
if (rhsConstant.getValue() != -1)
|
||||
return {};
|
||||
|
||||
if (auto lhsVar = lhs.dyn_cast<SDBMTermExpr>())
|
||||
return SDBMNegExpr::get(lhsVar);
|
||||
if (auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>())
|
||||
return SDBMNegator().visitDiff(lhsDiff);
|
||||
|
||||
// Other multiplications are not allowed in SDBM.
|
||||
return {};
|
||||
|
@ -383,7 +492,7 @@ Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
|
|||
// 'mod' can only be converted to SDBM if its LHS is a variable
|
||||
// and its RHS is a constant. Then it `x mod c = x - x stripe c`.
|
||||
auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
|
||||
auto lhsVar = rhs.dyn_cast<SDBMTermExpr>();
|
||||
auto lhsVar = lhs.dyn_cast<SDBMTermExpr>();
|
||||
if (!lhsVar || !rhsConstant)
|
||||
return {};
|
||||
return SDBMDiffExpr::get(lhsVar,
|
||||
|
@ -418,7 +527,7 @@ Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
|
|||
// SDBMDiffExpr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
SDBMDiffExpr SDBMDiffExpr::get(SDBMTermExpr lhs, SDBMTermExpr rhs) {
|
||||
SDBMDiffExpr SDBMDiffExpr::get(SDBMDirectExpr lhs, SDBMTermExpr rhs) {
|
||||
assert(lhs && "expected SDBM dimension");
|
||||
assert(rhs && "expected SDBM dimension");
|
||||
|
||||
|
@ -427,7 +536,7 @@ SDBMDiffExpr SDBMDiffExpr::get(SDBMTermExpr lhs, SDBMTermExpr rhs) {
|
|||
/*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Diff), lhs, rhs);
|
||||
}
|
||||
|
||||
SDBMTermExpr SDBMDiffExpr::getLHS() const {
|
||||
SDBMDirectExpr SDBMDiffExpr::getLHS() const {
|
||||
return static_cast<ImplType *>(impl)->lhs;
|
||||
}
|
||||
|
||||
|
@ -526,31 +635,36 @@ int64_t SDBMConstantExpr::getValue() const {
|
|||
// SDBMNegExpr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
SDBMNegExpr SDBMNegExpr::get(SDBMTermExpr var) {
|
||||
assert(var && "expected non-null SDBM variable expression");
|
||||
SDBMNegExpr SDBMNegExpr::get(SDBMDirectExpr var) {
|
||||
assert(var && "expected non-null SDBM direct expression");
|
||||
|
||||
StorageUniquer &uniquer = var.getDialect()->getUniquer();
|
||||
return uniquer.get<detail::SDBMNegExprStorage>(
|
||||
/*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Neg), var);
|
||||
}
|
||||
|
||||
SDBMTermExpr SDBMNegExpr::getVar() const {
|
||||
return static_cast<ImplType *>(impl)->dim;
|
||||
SDBMDirectExpr SDBMNegExpr::getVar() const {
|
||||
return static_cast<ImplType *>(impl)->expr;
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace ops_assertions {
|
||||
|
||||
SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs) {
|
||||
// If one of the operands is a negation, take a difference rather than a sum.
|
||||
auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>();
|
||||
auto rhsNeg = rhs.dyn_cast<SDBMNegExpr>();
|
||||
assert(!(lhsNeg && rhsNeg) && "a sum of negated expressions is a negation of "
|
||||
"a sum of variables and not a correct SDBM");
|
||||
if (lhsNeg)
|
||||
return rhs - lhsNeg.getVar();
|
||||
if (rhsNeg)
|
||||
return lhs - rhsNeg.getVar();
|
||||
if (auto folded = foldSumDiff(lhs, rhs))
|
||||
return folded;
|
||||
assert(!(lhs.isa<SDBMNegExpr>() && rhs.isa<SDBMNegExpr>()) &&
|
||||
"a sum of negated expressions is a negation of a sum of variables and "
|
||||
"not a correct SDBM");
|
||||
|
||||
// Fold (x - y) + (y - x) = 0.
|
||||
auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>();
|
||||
auto rhsDiff = rhs.dyn_cast<SDBMDiffExpr>();
|
||||
if (lhsDiff && rhsDiff) {
|
||||
if (lhsDiff.getLHS() == rhsDiff.getRHS() &&
|
||||
lhsDiff.getRHS() == rhsDiff.getLHS())
|
||||
return SDBMConstantExpr::get(lhs.getDialect(), 0);
|
||||
}
|
||||
|
||||
// If LHS is a constant and RHS is not, swap the order to get into a supported
|
||||
// sum case. From now on, RHS must be a constant.
|
||||
|
@ -562,26 +676,11 @@ SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs) {
|
|||
}
|
||||
assert(rhsConstant && "at least one operand must be a constant");
|
||||
|
||||
// If LHS is another sum, first compute the sum of its variable
|
||||
// part with the other argument and then add the constant part to enable
|
||||
// constant folding (the variable part may, e.g., be a negation that requires
|
||||
// to enter this function again).
|
||||
auto lhsSum = lhs.dyn_cast<SDBMSumExpr>();
|
||||
if (lhsSum)
|
||||
return lhsSum.getLHS() +
|
||||
(lhsSum.getRHS().getValue() + rhsConstant.getValue());
|
||||
|
||||
// Constant-fold if LHS is a constant.
|
||||
// Constant-fold if LHS is also a constant.
|
||||
if (lhsConstant)
|
||||
return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() +
|
||||
rhsConstant.getValue());
|
||||
|
||||
// Fold x + 0 == x.
|
||||
if (rhsConstant.getValue() == 0)
|
||||
return lhs;
|
||||
|
||||
return SDBMSumExpr::get(lhs.cast<SDBMVaryingExpr>(),
|
||||
rhs.cast<SDBMConstantExpr>());
|
||||
return addConstant(lhs.cast<SDBMVaryingExpr>(), rhsConstant.getValue());
|
||||
}
|
||||
|
||||
SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs) {
|
||||
|
@ -608,25 +707,7 @@ SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs) {
|
|||
if (lhsConstant)
|
||||
return -rhs + lhsConstant;
|
||||
|
||||
// Hoist constant factors outside the difference if any of sides is a sum:
|
||||
// (x + A) - (y - B) == x - y + (A - B).
|
||||
// If either LHS or RHS is a sum, collect the constant values separately and
|
||||
// update LHS and RHS to point to the variable part of the sum.
|
||||
auto lhsSum = lhs.dyn_cast<SDBMSumExpr>();
|
||||
auto rhsSum = rhs.dyn_cast<SDBMSumExpr>();
|
||||
int64_t value = 0;
|
||||
if (lhsSum) {
|
||||
value += lhsSum.getRHS().getValue();
|
||||
lhs = lhsSum.getLHS();
|
||||
}
|
||||
if (rhsSum) {
|
||||
value -= rhsSum.getRHS().getValue();
|
||||
rhs = rhsSum.getLHS();
|
||||
}
|
||||
|
||||
// This calls into operator+ for futher simplification in case value == 0.
|
||||
return SDBMDiffExpr::get(lhs.cast<SDBMTermExpr>(), rhs.cast<SDBMTermExpr>()) +
|
||||
value;
|
||||
return buildDiffExpr(lhs.cast<SDBMDirectExpr>(), (-rhs).cast<SDBMNegExpr>());
|
||||
}
|
||||
|
||||
SDBMExpr stripe(SDBMExpr expr, SDBMExpr factor) {
|
||||
|
|
|
@ -43,7 +43,7 @@ struct SDBMExprStorage : public StorageUniquer::BaseStorage {
|
|||
|
||||
// Storage class for SDBM sum and stripe expressions.
|
||||
struct SDBMBinaryExprStorage : public SDBMExprStorage {
|
||||
using KeyTy = std::pair<SDBMVaryingExpr, SDBMConstantExpr>;
|
||||
using KeyTy = std::pair<SDBMDirectExpr, SDBMConstantExpr>;
|
||||
|
||||
bool operator==(const KeyTy &key) const {
|
||||
return std::get<0>(key) == lhs && std::get<1>(key) == rhs;
|
||||
|
@ -58,13 +58,13 @@ struct SDBMBinaryExprStorage : public SDBMExprStorage {
|
|||
return result;
|
||||
}
|
||||
|
||||
SDBMVaryingExpr lhs;
|
||||
SDBMDirectExpr lhs;
|
||||
SDBMConstantExpr rhs;
|
||||
};
|
||||
|
||||
// Storage class for SDBM difference expressions.
|
||||
struct SDBMDiffExprStorage : public SDBMExprStorage {
|
||||
using KeyTy = std::pair<SDBMTermExpr, SDBMTermExpr>;
|
||||
using KeyTy = std::pair<SDBMDirectExpr, SDBMTermExpr>;
|
||||
|
||||
bool operator==(const KeyTy &key) const {
|
||||
return std::get<0>(key) == lhs && std::get<1>(key) == rhs;
|
||||
|
@ -79,7 +79,7 @@ struct SDBMDiffExprStorage : public SDBMExprStorage {
|
|||
return result;
|
||||
}
|
||||
|
||||
SDBMTermExpr lhs;
|
||||
SDBMDirectExpr lhs;
|
||||
SDBMTermExpr rhs;
|
||||
};
|
||||
|
||||
|
@ -117,19 +117,19 @@ struct SDBMTermExprStorage : public SDBMExprStorage {
|
|||
|
||||
// Storage class for SDBM negation expressions.
|
||||
struct SDBMNegExprStorage : public SDBMExprStorage {
|
||||
using KeyTy = SDBMTermExpr;
|
||||
using KeyTy = SDBMDirectExpr;
|
||||
|
||||
bool operator==(const KeyTy &key) const { return key == dim; }
|
||||
bool operator==(const KeyTy &key) const { return key == expr; }
|
||||
|
||||
static SDBMNegExprStorage *
|
||||
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
|
||||
auto *result = allocator.allocate<SDBMNegExprStorage>();
|
||||
result->dim = key;
|
||||
result->expr = key;
|
||||
result->dialect = key.getDialect();
|
||||
return result;
|
||||
}
|
||||
|
||||
SDBMTermExpr dim;
|
||||
SDBMDirectExpr expr;
|
||||
};
|
||||
|
||||
} // end namespace detail
|
||||
|
|
|
@ -161,13 +161,13 @@ TEST_FUNC(SDBM_StripeTightening) {
|
|||
|
||||
SmallVector<SDBMExpr, 4> eqs, ineqs;
|
||||
sdbm.getSDBMExpressions(dialect(), ineqs, eqs);
|
||||
// CHECK: s0 # 3 - d0 + -2
|
||||
// CHECK: s0 # 3 + -2 - d0
|
||||
// CHECK-EMPTY:
|
||||
for (auto ineq : ineqs)
|
||||
ineq.print(llvm::outs() << '\n');
|
||||
llvm::outs() << "\n";
|
||||
|
||||
// CHECK-DAG: d1 - d0 + -42
|
||||
// CHECK-DAG: d1 + -42 - d0
|
||||
// CHECK-DAG: d0 - s0 # 3 # 5
|
||||
for (auto eq : eqs)
|
||||
eq.print(llvm::outs() << '\n');
|
||||
|
|
|
@ -76,6 +76,28 @@ TEST(SDBMOperators, AddFolding) {
|
|||
|
||||
auto inverted = SDBMNegExpr::get(SDBMDimExpr::get(dialect(), 1)) + dim(0);
|
||||
EXPECT_EQ(inverted, expr);
|
||||
|
||||
// Check that opposite values cancel each other, and that we elide the zero
|
||||
// constant.
|
||||
expr = dim(0) + 42;
|
||||
auto onlyDim = expr - 42;
|
||||
EXPECT_EQ(onlyDim, dim(0));
|
||||
|
||||
// Check that we can sink a constant under a negation.
|
||||
expr = -(dim(0) + 2);
|
||||
auto negatedSum = (expr + 10).dyn_cast<SDBMNegExpr>();
|
||||
ASSERT_TRUE(negatedSum);
|
||||
auto sum = negatedSum.getVar().dyn_cast<SDBMSumExpr>();
|
||||
ASSERT_TRUE(sum);
|
||||
EXPECT_EQ(sum.getRHS().getValue(), -8);
|
||||
|
||||
// Sum with zero is the same as the original expression.
|
||||
EXPECT_EQ(dim(0) + 0, dim(0));
|
||||
|
||||
// Sum of opposite differences is zero.
|
||||
auto diffOfDiffs =
|
||||
((dim(0) - dim(1)) + (dim(1) - dim(0))).dyn_cast<SDBMConstantExpr>();
|
||||
EXPECT_EQ(diffOfDiffs.getValue(), 0);
|
||||
}
|
||||
|
||||
TEST(SDBMOperators, Diff) {
|
||||
|
@ -101,6 +123,43 @@ TEST(SDBMOperators, DiffFolding) {
|
|||
constantExpr = zero.dyn_cast<SDBMConstantExpr>();
|
||||
ASSERT_TRUE(constantExpr);
|
||||
EXPECT_EQ(constantExpr.getValue(), 0);
|
||||
|
||||
// Check that the constant terms in difference-of-sums are folded.
|
||||
// (d0 - 3) - (d1 - 5) = (d0 + 2) - d1
|
||||
auto diffOfSums = ((dim(0) - 3) - (dim(1) - 5)).dyn_cast<SDBMDiffExpr>();
|
||||
ASSERT_TRUE(diffOfSums);
|
||||
auto lhs = diffOfSums.getLHS().dyn_cast<SDBMSumExpr>();
|
||||
ASSERT_TRUE(lhs);
|
||||
EXPECT_EQ(lhs.getLHS(), dim(0));
|
||||
EXPECT_EQ(lhs.getRHS().getValue(), 2);
|
||||
EXPECT_EQ(diffOfSums.getRHS(), dim(1));
|
||||
|
||||
// Check that identical dimensions with opposite signs cancel each other.
|
||||
auto cstOnly = ((dim(0) + 42) - dim(0)).dyn_cast<SDBMConstantExpr>();
|
||||
ASSERT_TRUE(cstOnly);
|
||||
EXPECT_EQ(cstOnly.getValue(), 42);
|
||||
|
||||
// Check that identical terms in sum of diffs cancel out.
|
||||
auto dimOnly = (-dim(0) + (dim(0) - dim(1)));
|
||||
EXPECT_EQ(dimOnly, -dim(1));
|
||||
dimOnly = (dim(0) - dim(1)) + (-dim(0));
|
||||
EXPECT_EQ(dimOnly, -dim(1));
|
||||
dimOnly = (dim(0) - dim(1)) + dim(1);
|
||||
EXPECT_EQ(dimOnly, dim(0));
|
||||
dimOnly = dim(0) + (dim(1) - dim(0));
|
||||
EXPECT_EQ(dimOnly, dim(1));
|
||||
|
||||
// Top-level zero constant is fine.
|
||||
cstOnly = (-symb(1) + symb(1)).dyn_cast<SDBMConstantExpr>();
|
||||
ASSERT_TRUE(cstOnly);
|
||||
EXPECT_EQ(cstOnly.getValue(), 0);
|
||||
}
|
||||
|
||||
TEST(SDBMOperators, Negate) {
|
||||
auto sum = dim(0) + 3;
|
||||
auto negated = (-sum).dyn_cast<SDBMNegExpr>();
|
||||
ASSERT_TRUE(negated);
|
||||
EXPECT_EQ(negated.getVar(), sum);
|
||||
}
|
||||
|
||||
TEST(SDBMOperators, Stripe) {
|
||||
|
@ -324,11 +383,11 @@ TEST(SDBMExpr, AffineRoundTrip) {
|
|||
|
||||
// Check that (s0 # 2 # 5 - s0 # 2) + 2 can be converted as an example of a
|
||||
// deeper expression tree.
|
||||
auto diff = SDBMDiffExpr::get(outerStripe, stripe);
|
||||
auto sum = SDBMSumExpr::get(diff, cst2);
|
||||
roundtripped = SDBMExpr::tryConvertAffineExpr(sum.getAsAffineExpr());
|
||||
auto sum = SDBMSumExpr::get(outerStripe, cst2);
|
||||
auto diff = SDBMDiffExpr::get(sum, stripe);
|
||||
roundtripped = SDBMExpr::tryConvertAffineExpr(diff.getAsAffineExpr());
|
||||
ASSERT_TRUE(roundtripped.hasValue());
|
||||
EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(sum));
|
||||
EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(diff));
|
||||
}
|
||||
|
||||
TEST(SDBMExpr, MatchStripeMulPattern) {
|
||||
|
|
Loading…
Reference in New Issue