forked from OSchip/llvm-project
Overhaul the SDBM expression kind hierarchy
Swap the allowed nesting of sum and diff expressions: now a diff expression can contain a sum expression, but only on the left hand side. A difference of two expressions sum must be canonicalized by grouping their constant terms in a single expression. This change of sturcture became possible thanks to the introduction of the "direct" super-kind. It is necessary to enable support of sum expressions on the left hand side of the stripe expression. SDBM expressions are now grouped into the following structure - expression - varying - direct - sum <- (term, constant) - term - symbol - dimension - stripe <- (term, constant) - negation <- (direct) - difference <- (direct, term) - constant The notation <- (...) denotes the types of subexpressions a compound expression can combine. PiperOrigin-RevId: 269337222
This commit is contained in:
parent
e94db619d9
commit
cb3ecb5291
|
@ -68,6 +68,28 @@ class SDBMSymbolExpr;
|
||||||
/// not combine in more cases than they do. This choice may be reconsidered in
|
/// not combine in more cases than they do. This choice may be reconsidered in
|
||||||
/// the future.
|
/// the future.
|
||||||
///
|
///
|
||||||
|
/// SDBM expressions are grouped into the following structure
|
||||||
|
/// - expression
|
||||||
|
/// - varying
|
||||||
|
/// - direct
|
||||||
|
/// - sum <- (term, constant)
|
||||||
|
/// - term
|
||||||
|
/// - symbol
|
||||||
|
/// - dimension
|
||||||
|
/// - stripe <- (term, constant)
|
||||||
|
/// - negation <- (direct)
|
||||||
|
/// - difference <- (direct, term)
|
||||||
|
/// - constant
|
||||||
|
/// The notation <- (...) denotes the types of subexpressions a compound
|
||||||
|
/// expression can combine. The tree of subexpressions essentially imposes the
|
||||||
|
/// following canonicalization rules:
|
||||||
|
/// - constants are always folded;
|
||||||
|
/// - constants can only appear on the RHS of an expression;
|
||||||
|
/// - double negation must be elided;
|
||||||
|
/// - an additive constant term is only allowed in a sum expression, and
|
||||||
|
/// should be sunk into the nearest such expression in the tree;
|
||||||
|
/// - zero constant expression can only appear at the top level.
|
||||||
|
///
|
||||||
/// `SDBMExpr` and derived classes are thin wrappers around a pointer owned by
|
/// `SDBMExpr` and derived classes are thin wrappers around a pointer owned by
|
||||||
/// an MLIRContext, and should be used by-value. They are uniqued in the
|
/// an MLIRContext, and should be used by-value. They are uniqued in the
|
||||||
/// MLIRContext and immortal.
|
/// MLIRContext and immortal.
|
||||||
|
@ -208,39 +230,42 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// SDBM sum expression. LHS is a varying expression and RHS is always a
|
/// SDBM sum expression. LHS is a term expression and RHS is a constant.
|
||||||
/// constant expression.
|
|
||||||
class SDBMSumExpr : public SDBMDirectExpr {
|
class SDBMSumExpr : public SDBMDirectExpr {
|
||||||
public:
|
public:
|
||||||
using ImplType = detail::SDBMBinaryExprStorage;
|
using ImplType = detail::SDBMBinaryExprStorage;
|
||||||
using SDBMDirectExpr::SDBMDirectExpr;
|
using SDBMDirectExpr::SDBMDirectExpr;
|
||||||
|
|
||||||
/// Obtain or create a sum expression unique'ed in the given context.
|
/// Obtain or create a sum expression unique'ed in the given context.
|
||||||
static SDBMSumExpr get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs);
|
static SDBMSumExpr get(SDBMTermExpr lhs, SDBMConstantExpr rhs);
|
||||||
|
|
||||||
static bool isClassFor(const SDBMExpr &expr) {
|
static bool isClassFor(const SDBMExpr &expr) {
|
||||||
SDBMExprKind kind = expr.getKind();
|
SDBMExprKind kind = expr.getKind();
|
||||||
return kind == SDBMExprKind::Add;
|
return kind == SDBMExprKind::Add;
|
||||||
}
|
}
|
||||||
|
|
||||||
SDBMVaryingExpr getLHS() const;
|
SDBMTermExpr getLHS() const;
|
||||||
SDBMConstantExpr getRHS() const;
|
SDBMConstantExpr getRHS() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// SDBM difference expression. Both LHS and RHS are SDBM term expressions.
|
/// SDBM difference expression. LHS is a direct expression, i.e. it may be a
|
||||||
|
/// sum of a term and a constant. RHS is a term expression. Thus the
|
||||||
|
/// expression (t1 - t2 + C) with term expressions t1,t2 is represented as
|
||||||
|
/// diff(sum(t1, C), t2)
|
||||||
|
/// and it is possible to extract the constant factor without negating it.
|
||||||
class SDBMDiffExpr : public SDBMVaryingExpr {
|
class SDBMDiffExpr : public SDBMVaryingExpr {
|
||||||
public:
|
public:
|
||||||
using ImplType = detail::SDBMDiffExprStorage;
|
using ImplType = detail::SDBMDiffExprStorage;
|
||||||
using SDBMVaryingExpr::SDBMVaryingExpr;
|
using SDBMVaryingExpr::SDBMVaryingExpr;
|
||||||
|
|
||||||
/// Obtain or create a difference expression unique'ed in the given context.
|
/// Obtain or create a difference expression unique'ed in the given context.
|
||||||
static SDBMDiffExpr get(SDBMTermExpr lhs, SDBMTermExpr rhs);
|
static SDBMDiffExpr get(SDBMDirectExpr lhs, SDBMTermExpr rhs);
|
||||||
|
|
||||||
static bool isClassFor(const SDBMExpr &expr) {
|
static bool isClassFor(const SDBMExpr &expr) {
|
||||||
return expr.getKind() == SDBMExprKind::Diff;
|
return expr.getKind() == SDBMExprKind::Diff;
|
||||||
}
|
}
|
||||||
|
|
||||||
SDBMTermExpr getLHS() const;
|
SDBMDirectExpr getLHS() const;
|
||||||
SDBMTermExpr getRHS() const;
|
SDBMTermExpr getRHS() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -319,13 +344,13 @@ public:
|
||||||
using SDBMVaryingExpr::SDBMVaryingExpr;
|
using SDBMVaryingExpr::SDBMVaryingExpr;
|
||||||
|
|
||||||
/// Obtain or create a negation expression unique'ed in the given context.
|
/// Obtain or create a negation expression unique'ed in the given context.
|
||||||
static SDBMNegExpr get(SDBMTermExpr var);
|
static SDBMNegExpr get(SDBMDirectExpr var);
|
||||||
|
|
||||||
static bool isClassFor(const SDBMExpr &expr) {
|
static bool isClassFor(const SDBMExpr &expr) {
|
||||||
return expr.getKind() == SDBMExprKind::Neg;
|
return expr.getKind() == SDBMExprKind::Neg;
|
||||||
}
|
}
|
||||||
|
|
||||||
SDBMTermExpr getVar() const;
|
SDBMDirectExpr getVar() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// A visitor class for SDBM expressions. Calls the kind-specific function
|
/// A visitor class for SDBM expressions. Calls the kind-specific function
|
||||||
|
@ -490,22 +515,22 @@ template <> struct DenseMapInfo<mlir::SDBMExpr> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// SDBMVaryingExpr hash just like pointers.
|
// SDBMDirectExpr hash just like pointers.
|
||||||
template <> struct DenseMapInfo<mlir::SDBMVaryingExpr> {
|
template <> struct DenseMapInfo<mlir::SDBMDirectExpr> {
|
||||||
static mlir::SDBMVaryingExpr getEmptyKey() {
|
static mlir::SDBMDirectExpr getEmptyKey() {
|
||||||
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||||
return mlir::SDBMVaryingExpr(
|
return mlir::SDBMDirectExpr(
|
||||||
static_cast<mlir::SDBMExpr::ImplType *>(pointer));
|
static_cast<mlir::SDBMExpr::ImplType *>(pointer));
|
||||||
}
|
}
|
||||||
static mlir::SDBMVaryingExpr getTombstoneKey() {
|
static mlir::SDBMDirectExpr getTombstoneKey() {
|
||||||
auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
|
auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
|
||||||
return mlir::SDBMVaryingExpr(
|
return mlir::SDBMDirectExpr(
|
||||||
static_cast<mlir::SDBMExpr::ImplType *>(pointer));
|
static_cast<mlir::SDBMExpr::ImplType *>(pointer));
|
||||||
}
|
}
|
||||||
static unsigned getHashValue(mlir::SDBMVaryingExpr expr) {
|
static unsigned getHashValue(mlir::SDBMDirectExpr expr) {
|
||||||
return expr.hash_value();
|
return expr.hash_value();
|
||||||
}
|
}
|
||||||
static bool isEqual(mlir::SDBMVaryingExpr lhs, mlir::SDBMVaryingExpr rhs) {
|
static bool isEqual(mlir::SDBMDirectExpr lhs, mlir::SDBMDirectExpr rhs) {
|
||||||
return lhs == rhs;
|
return lhs == rhs;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -161,14 +161,14 @@ void SDBMExpr::print(raw_ostream &os) const {
|
||||||
Printer(raw_ostream &ostream) : prn(ostream) {}
|
Printer(raw_ostream &ostream) : prn(ostream) {}
|
||||||
|
|
||||||
void visitSum(SDBMSumExpr expr) {
|
void visitSum(SDBMSumExpr expr) {
|
||||||
visitVarying(expr.getLHS());
|
visit(expr.getLHS());
|
||||||
prn << " + ";
|
prn << " + ";
|
||||||
visitConstant(expr.getRHS());
|
visit(expr.getRHS());
|
||||||
}
|
}
|
||||||
void visitDiff(SDBMDiffExpr expr) {
|
void visitDiff(SDBMDiffExpr expr) {
|
||||||
visitTerm(expr.getLHS());
|
visit(expr.getLHS());
|
||||||
prn << " - ";
|
prn << " - ";
|
||||||
visitTerm(expr.getRHS());
|
visit(expr.getRHS());
|
||||||
}
|
}
|
||||||
void visitDim(SDBMDimExpr expr) { prn << 'd' << expr.getPosition(); }
|
void visitDim(SDBMDimExpr expr) { prn << 'd' << expr.getPosition(); }
|
||||||
void visitSymbol(SDBMSymbolExpr expr) { prn << 's' << expr.getPosition(); }
|
void visitSymbol(SDBMSymbolExpr expr) { prn << 's' << expr.getPosition(); }
|
||||||
|
@ -178,8 +178,13 @@ void SDBMExpr::print(raw_ostream &os) const {
|
||||||
visitConstant(expr.getStripeFactor());
|
visitConstant(expr.getStripeFactor());
|
||||||
}
|
}
|
||||||
void visitNeg(SDBMNegExpr expr) {
|
void visitNeg(SDBMNegExpr expr) {
|
||||||
|
bool isSum = expr.getVar().isa<SDBMSumExpr>();
|
||||||
prn << '-';
|
prn << '-';
|
||||||
visitTerm(expr.getVar());
|
if (isSum)
|
||||||
|
prn << '(';
|
||||||
|
visit(expr.getVar());
|
||||||
|
if (isSum)
|
||||||
|
prn << ')';
|
||||||
}
|
}
|
||||||
void visitConstant(SDBMConstantExpr expr) { prn << expr.getValue(); }
|
void visitConstant(SDBMConstantExpr expr) { prn << expr.getValue(); }
|
||||||
|
|
||||||
|
@ -199,7 +204,7 @@ namespace {
|
||||||
struct SDBMNegator : public SDBMVisitor<SDBMNegator, SDBMExpr> {
|
struct SDBMNegator : public SDBMVisitor<SDBMNegator, SDBMExpr> {
|
||||||
// Any term expression is wrapped into a negation expression.
|
// Any term expression is wrapped into a negation expression.
|
||||||
// -(x) = -x
|
// -(x) = -x
|
||||||
SDBMExpr visitTerm(SDBMTermExpr expr) { return SDBMNegExpr::get(expr); }
|
SDBMExpr visitDirect(SDBMDirectExpr expr) { return SDBMNegExpr::get(expr); }
|
||||||
// A negation expression is unwrapped.
|
// A negation expression is unwrapped.
|
||||||
// -(-x) = x
|
// -(-x) = x
|
||||||
SDBMExpr visitNeg(SDBMNegExpr expr) { return expr.getVar(); }
|
SDBMExpr visitNeg(SDBMNegExpr expr) { return expr.getVar(); }
|
||||||
|
@ -207,15 +212,20 @@ struct SDBMNegator : public SDBMVisitor<SDBMNegator, SDBMExpr> {
|
||||||
SDBMExpr visitConstant(SDBMConstantExpr expr) {
|
SDBMExpr visitConstant(SDBMConstantExpr expr) {
|
||||||
return SDBMConstantExpr::get(expr.getDialect(), -expr.getValue());
|
return SDBMConstantExpr::get(expr.getDialect(), -expr.getValue());
|
||||||
}
|
}
|
||||||
// Both terms of the sum are negated recursively.
|
|
||||||
SDBMExpr visitSum(SDBMSumExpr expr) {
|
// Terms of a difference are interchanged. Since only the LHS of a diff
|
||||||
return SDBMSumExpr::get(visit(expr.getLHS()).cast<SDBMVaryingExpr>(),
|
// expression is allowed to be a sum with a constant, we need to recreate the
|
||||||
visit(expr.getRHS()).cast<SDBMConstantExpr>());
|
// sum with the negated value:
|
||||||
}
|
// -((x + C) - y) = (y - C) - x.
|
||||||
// Terms of a difference are interchanged.
|
|
||||||
// -(x - y) = y - x
|
|
||||||
SDBMExpr visitDiff(SDBMDiffExpr expr) {
|
SDBMExpr visitDiff(SDBMDiffExpr expr) {
|
||||||
return SDBMDiffExpr::get(expr.getRHS(), expr.getLHS());
|
// If the LHS is just a term, we can do straightforward interchange.
|
||||||
|
if (auto term = expr.getLHS().dyn_cast<SDBMTermExpr>())
|
||||||
|
return SDBMDiffExpr::get(expr.getRHS(), term);
|
||||||
|
|
||||||
|
auto sum = expr.getLHS().cast<SDBMSumExpr>();
|
||||||
|
auto cst = visitConstant(sum.getRHS()).cast<SDBMConstantExpr>();
|
||||||
|
return SDBMDiffExpr::get(SDBMSumExpr::get(expr.getRHS(), cst),
|
||||||
|
sum.getLHS());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -226,7 +236,7 @@ SDBMExpr SDBMExpr::operator-() { return SDBMNegator().visit(*this); }
|
||||||
// SDBMSumExpr
|
// SDBMSumExpr
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
SDBMSumExpr SDBMSumExpr::get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs) {
|
SDBMSumExpr SDBMSumExpr::get(SDBMTermExpr lhs, SDBMConstantExpr rhs) {
|
||||||
assert(lhs && "expected SDBM variable expression");
|
assert(lhs && "expected SDBM variable expression");
|
||||||
assert(rhs && "expected SDBM constant");
|
assert(rhs && "expected SDBM constant");
|
||||||
|
|
||||||
|
@ -242,8 +252,8 @@ SDBMSumExpr SDBMSumExpr::get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs) {
|
||||||
/*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Add), lhs, rhs);
|
/*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Add), lhs, rhs);
|
||||||
}
|
}
|
||||||
|
|
||||||
SDBMVaryingExpr SDBMSumExpr::getLHS() const {
|
SDBMTermExpr SDBMSumExpr::getLHS() const {
|
||||||
return static_cast<ImplType *>(impl)->lhs;
|
return static_cast<ImplType *>(impl)->lhs.cast<SDBMTermExpr>();
|
||||||
}
|
}
|
||||||
|
|
||||||
SDBMConstantExpr SDBMSumExpr::getRHS() const {
|
SDBMConstantExpr SDBMSumExpr::getRHS() const {
|
||||||
|
@ -289,24 +299,126 @@ AffineExpr SDBMExpr::getAsAffineExpr() const {
|
||||||
return converter.visit(*this);
|
return converter.visit(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Given a direct expression `expr`, add the given constant to it and pass the
|
||||||
|
// resulting expression to `builder` before returning its result. If the
|
||||||
|
// expression is already a sum expression, update its constant and extract the
|
||||||
|
// LHS if the constant becomes zero. Otherwise, construct a sum expression.
|
||||||
|
template <typename Result>
|
||||||
|
Result addConstantAndSink(SDBMDirectExpr expr, int64_t constant, bool negated,
|
||||||
|
llvm::function_ref<Result(SDBMDirectExpr)> builder) {
|
||||||
|
SDBMDialect *dialect = expr.getDialect();
|
||||||
|
if (auto sumExpr = expr.dyn_cast<SDBMSumExpr>()) {
|
||||||
|
if (negated)
|
||||||
|
constant = sumExpr.getRHS().getValue() - constant;
|
||||||
|
else
|
||||||
|
constant += sumExpr.getRHS().getValue();
|
||||||
|
|
||||||
|
if (constant != 0) {
|
||||||
|
auto sum = SDBMSumExpr::get(sumExpr.getLHS(),
|
||||||
|
SDBMConstantExpr::get(dialect, constant));
|
||||||
|
return builder(sum);
|
||||||
|
} else {
|
||||||
|
return builder(sumExpr.getLHS());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (constant != 0)
|
||||||
|
return builder(SDBMSumExpr::get(
|
||||||
|
expr.cast<SDBMTermExpr>(),
|
||||||
|
SDBMConstantExpr::get(dialect, negated ? -constant : constant)));
|
||||||
|
return expr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Construct an expression lhs + constant while maintaining the canonical form
|
||||||
|
// of the SDBM expressions, in particular sink the constant expression to the
|
||||||
|
// nearest sum expression in the left subtree of the expresison tree.
|
||||||
|
static SDBMExpr addConstant(SDBMVaryingExpr lhs, int64_t constant) {
|
||||||
|
if (auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>())
|
||||||
|
return addConstantAndSink<SDBMExpr>(
|
||||||
|
lhsDiff.getLHS(), constant, /*negated=*/false,
|
||||||
|
[lhsDiff](SDBMDirectExpr e) {
|
||||||
|
return SDBMDiffExpr::get(e, lhsDiff.getRHS());
|
||||||
|
});
|
||||||
|
if (auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>())
|
||||||
|
return addConstantAndSink<SDBMExpr>(
|
||||||
|
lhsNeg.getVar(), constant, /*negated=*/true,
|
||||||
|
[](SDBMDirectExpr e) { return SDBMNegExpr::get(e); });
|
||||||
|
if (auto lhsSum = lhs.dyn_cast<SDBMSumExpr>())
|
||||||
|
return addConstantAndSink<SDBMExpr>(lhsSum, constant, /*negated=*/false,
|
||||||
|
[](SDBMDirectExpr e) { return e; });
|
||||||
|
if (constant != 0)
|
||||||
|
return SDBMSumExpr::get(lhs.cast<SDBMTermExpr>(),
|
||||||
|
SDBMConstantExpr::get(lhs.getDialect(), constant));
|
||||||
|
return lhs;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build a difference expression given a direct expression and a negation
|
||||||
|
// expression.
|
||||||
|
static SDBMExpr buildDiffExpr(SDBMDirectExpr lhs, SDBMNegExpr rhs) {
|
||||||
|
SDBMTermExpr lhsTerm, rhsTerm;
|
||||||
|
int lhsConstant = 0;
|
||||||
|
int64_t rhsConstant = 0;
|
||||||
|
|
||||||
|
if (auto lhsSum = lhs.dyn_cast<SDBMSumExpr>()) {
|
||||||
|
lhsConstant = lhsSum.getRHS().getValue();
|
||||||
|
lhsTerm = lhsSum.getLHS();
|
||||||
|
} else {
|
||||||
|
lhsTerm = lhs.cast<SDBMTermExpr>();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto rhsNegatedSum = rhs.getVar().dyn_cast<SDBMSumExpr>()) {
|
||||||
|
rhsTerm = rhsNegatedSum.getLHS();
|
||||||
|
rhsConstant = rhsNegatedSum.getRHS().getValue();
|
||||||
|
} else {
|
||||||
|
rhsTerm = rhs.getVar().cast<SDBMTermExpr>();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fold (x + C) - (x + D) = C - D.
|
||||||
|
if (lhsTerm == rhsTerm)
|
||||||
|
return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant - rhsConstant);
|
||||||
|
|
||||||
|
return SDBMDiffExpr::get(
|
||||||
|
addConstantAndSink<SDBMDirectExpr>(lhs, -rhsConstant, /*negated=*/false,
|
||||||
|
[](SDBMDirectExpr e) { return e; }),
|
||||||
|
rhsTerm);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try folding an expression (lhs + rhs) where at least one of the operands
|
||||||
|
// contains a negated variable, i.e. is a negation or a difference expression.
|
||||||
|
static SDBMExpr foldSumDiff(SDBMExpr lhs, SDBMExpr rhs) {
|
||||||
|
// If exactly one of LHS, RHS is a negation expression, we can construct
|
||||||
|
// a difference expression, which is a special kind in SDBM.
|
||||||
|
auto lhsDirect = lhs.dyn_cast<SDBMDirectExpr>();
|
||||||
|
auto rhsDirect = rhs.dyn_cast<SDBMDirectExpr>();
|
||||||
|
auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>();
|
||||||
|
auto rhsNeg = rhs.dyn_cast<SDBMNegExpr>();
|
||||||
|
|
||||||
|
if (lhsDirect && rhsNeg)
|
||||||
|
return buildDiffExpr(lhsDirect, rhsNeg);
|
||||||
|
if (lhsNeg && rhsDirect)
|
||||||
|
return buildDiffExpr(rhsDirect, lhsNeg);
|
||||||
|
|
||||||
|
// If a subexpression appears in a diff expression on the LHS(RHS) of a
|
||||||
|
// sum expression where it also appears on the RHS(LHS) with the opposite
|
||||||
|
// sign, we can simplify it away and obtain the SDBM form.
|
||||||
|
// x - (x - C) = -(x - C) + x = C
|
||||||
|
// (x - C) - x = -x + (x - C) = -C
|
||||||
|
auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>();
|
||||||
|
auto rhsDiff = rhs.dyn_cast<SDBMDiffExpr>();
|
||||||
|
if (lhsNeg && rhsDiff && lhsNeg.getVar() == rhsDiff.getLHS())
|
||||||
|
return SDBMNegExpr::get(rhsDiff.getRHS());
|
||||||
|
if (lhsDirect && rhsDiff && lhsDirect == rhsDiff.getRHS())
|
||||||
|
return rhsDiff.getLHS();
|
||||||
|
if (lhsDiff && rhsNeg && lhsDiff.getLHS() == rhsNeg.getVar())
|
||||||
|
return SDBMNegExpr::get(lhsDiff.getRHS());
|
||||||
|
if (rhsDirect && lhsDiff && rhsDirect == lhsDiff.getRHS())
|
||||||
|
return lhsDiff.getLHS();
|
||||||
|
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
|
Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
|
||||||
struct Converter : public AffineExprVisitor<Converter, SDBMExpr> {
|
struct Converter : public AffineExprVisitor<Converter, SDBMExpr> {
|
||||||
SDBMExpr visitAddExpr(AffineBinaryOpExpr expr) {
|
SDBMExpr visitAddExpr(AffineBinaryOpExpr expr) {
|
||||||
// Attempt to recover a stripe expression. Because AffineExprs don't have
|
|
||||||
// a first-class difference kind, we check for both x + -1 * (x mod C) and
|
|
||||||
// -1 * (x mod C) + x cases.
|
|
||||||
AffineExprMatcher x, C, m;
|
|
||||||
AffineExprMatcher pattern1 = ((x % C) * m) + x;
|
|
||||||
AffineExprMatcher pattern2 = x + ((x % C) * m);
|
|
||||||
if ((pattern1.match(expr) && m.getMatchedConstantValue() == -1) ||
|
|
||||||
(pattern2.match(expr) && m.getMatchedConstantValue() == -1)) {
|
|
||||||
if (auto convertedLHS = visit(x.matched())) {
|
|
||||||
// TODO(ntv): return convertedLHS.stripe(C);
|
|
||||||
return SDBMStripeExpr::get(
|
|
||||||
convertedLHS.cast<SDBMTermExpr>(),
|
|
||||||
visit(C.matched()).cast<SDBMConstantExpr>());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
|
auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
|
||||||
if (!lhs || !rhs)
|
if (!lhs || !rhs)
|
||||||
return {};
|
return {};
|
||||||
|
@ -314,29 +426,22 @@ Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
|
||||||
// In a "add" AffineExpr, the constant always appears on the right. If
|
// In a "add" AffineExpr, the constant always appears on the right. If
|
||||||
// there were two constants, they would have been folded away.
|
// there were two constants, they would have been folded away.
|
||||||
assert(!lhs.isa<SDBMConstantExpr>() && "non-canonical affine expression");
|
assert(!lhs.isa<SDBMConstantExpr>() && "non-canonical affine expression");
|
||||||
auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
|
|
||||||
|
|
||||||
// SDBM accepts LHS variables and RHS constants in a sum.
|
// If RHS is a constant, we can always extend the SDBM expression to
|
||||||
auto lhsVar = lhs.dyn_cast<SDBMVaryingExpr>();
|
// include it by sinking the constant into the nearest sum expresion.
|
||||||
auto rhsVar = rhs.dyn_cast<SDBMVaryingExpr>();
|
if (auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>()) {
|
||||||
if (rhsConstant && lhsVar)
|
assert(!lhs.isa<SDBMSumExpr>() && "unexpected non-canonicalized sum");
|
||||||
return SDBMSumExpr::get(lhsVar, rhsConstant);
|
|
||||||
|
|
||||||
// The sum of a negated variable and a non-negated variable is a
|
int64_t constant = rhsConstant.getValue();
|
||||||
// difference, supported as a special kind in SDBM. Because AffineExprs
|
auto varying = lhs.dyn_cast<SDBMVaryingExpr>();
|
||||||
// don't have first-class difference kind, check both LHS and RHS for
|
assert(varying && "unexpected uncanonicalized sum of constants");
|
||||||
// negation.
|
return addConstant(varying, constant);
|
||||||
auto lhsPos = lhs.dyn_cast<SDBMTermExpr>();
|
}
|
||||||
auto rhsPos = rhs.dyn_cast<SDBMTermExpr>();
|
|
||||||
auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>();
|
|
||||||
auto rhsNeg = rhs.dyn_cast<SDBMNegExpr>();
|
|
||||||
if (lhsNeg && rhsVar)
|
|
||||||
return SDBMDiffExpr::get(rhsPos, lhsNeg.getVar());
|
|
||||||
if (rhsNeg && lhsVar)
|
|
||||||
return SDBMDiffExpr::get(lhsPos, rhsNeg.getVar());
|
|
||||||
|
|
||||||
// Other cases don't fit into SDBM.
|
// Try building a difference expression if one of the values is negated,
|
||||||
return {};
|
// or check if a difference on either hand side cancels out the outer term
|
||||||
|
// so as to remain correct within SDBM. Return null otherwise.
|
||||||
|
return foldSumDiff(lhs, rhs);
|
||||||
}
|
}
|
||||||
|
|
||||||
SDBMExpr visitMulExpr(AffineBinaryOpExpr expr) {
|
SDBMExpr visitMulExpr(AffineBinaryOpExpr expr) {
|
||||||
|
@ -367,9 +472,13 @@ Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
|
||||||
|
|
||||||
// The only supported "multiplication" expression is an SDBM is dimension
|
// The only supported "multiplication" expression is an SDBM is dimension
|
||||||
// negation, that is a product of dimension and constant -1.
|
// negation, that is a product of dimension and constant -1.
|
||||||
auto lhsVar = lhs.dyn_cast<SDBMTermExpr>();
|
if (rhsConstant.getValue() != -1)
|
||||||
if (lhsVar && rhsConstant.getValue() == -1)
|
return {};
|
||||||
|
|
||||||
|
if (auto lhsVar = lhs.dyn_cast<SDBMTermExpr>())
|
||||||
return SDBMNegExpr::get(lhsVar);
|
return SDBMNegExpr::get(lhsVar);
|
||||||
|
if (auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>())
|
||||||
|
return SDBMNegator().visitDiff(lhsDiff);
|
||||||
|
|
||||||
// Other multiplications are not allowed in SDBM.
|
// Other multiplications are not allowed in SDBM.
|
||||||
return {};
|
return {};
|
||||||
|
@ -383,7 +492,7 @@ Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
|
||||||
// 'mod' can only be converted to SDBM if its LHS is a variable
|
// 'mod' can only be converted to SDBM if its LHS is a variable
|
||||||
// and its RHS is a constant. Then it `x mod c = x - x stripe c`.
|
// and its RHS is a constant. Then it `x mod c = x - x stripe c`.
|
||||||
auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
|
auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
|
||||||
auto lhsVar = rhs.dyn_cast<SDBMTermExpr>();
|
auto lhsVar = lhs.dyn_cast<SDBMTermExpr>();
|
||||||
if (!lhsVar || !rhsConstant)
|
if (!lhsVar || !rhsConstant)
|
||||||
return {};
|
return {};
|
||||||
return SDBMDiffExpr::get(lhsVar,
|
return SDBMDiffExpr::get(lhsVar,
|
||||||
|
@ -418,7 +527,7 @@ Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
|
||||||
// SDBMDiffExpr
|
// SDBMDiffExpr
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
SDBMDiffExpr SDBMDiffExpr::get(SDBMTermExpr lhs, SDBMTermExpr rhs) {
|
SDBMDiffExpr SDBMDiffExpr::get(SDBMDirectExpr lhs, SDBMTermExpr rhs) {
|
||||||
assert(lhs && "expected SDBM dimension");
|
assert(lhs && "expected SDBM dimension");
|
||||||
assert(rhs && "expected SDBM dimension");
|
assert(rhs && "expected SDBM dimension");
|
||||||
|
|
||||||
|
@ -427,7 +536,7 @@ SDBMDiffExpr SDBMDiffExpr::get(SDBMTermExpr lhs, SDBMTermExpr rhs) {
|
||||||
/*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Diff), lhs, rhs);
|
/*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Diff), lhs, rhs);
|
||||||
}
|
}
|
||||||
|
|
||||||
SDBMTermExpr SDBMDiffExpr::getLHS() const {
|
SDBMDirectExpr SDBMDiffExpr::getLHS() const {
|
||||||
return static_cast<ImplType *>(impl)->lhs;
|
return static_cast<ImplType *>(impl)->lhs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -526,31 +635,36 @@ int64_t SDBMConstantExpr::getValue() const {
|
||||||
// SDBMNegExpr
|
// SDBMNegExpr
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
SDBMNegExpr SDBMNegExpr::get(SDBMTermExpr var) {
|
SDBMNegExpr SDBMNegExpr::get(SDBMDirectExpr var) {
|
||||||
assert(var && "expected non-null SDBM variable expression");
|
assert(var && "expected non-null SDBM direct expression");
|
||||||
|
|
||||||
StorageUniquer &uniquer = var.getDialect()->getUniquer();
|
StorageUniquer &uniquer = var.getDialect()->getUniquer();
|
||||||
return uniquer.get<detail::SDBMNegExprStorage>(
|
return uniquer.get<detail::SDBMNegExprStorage>(
|
||||||
/*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Neg), var);
|
/*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Neg), var);
|
||||||
}
|
}
|
||||||
|
|
||||||
SDBMTermExpr SDBMNegExpr::getVar() const {
|
SDBMDirectExpr SDBMNegExpr::getVar() const {
|
||||||
return static_cast<ImplType *>(impl)->dim;
|
return static_cast<ImplType *>(impl)->expr;
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace ops_assertions {
|
namespace ops_assertions {
|
||||||
|
|
||||||
SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs) {
|
SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs) {
|
||||||
// If one of the operands is a negation, take a difference rather than a sum.
|
if (auto folded = foldSumDiff(lhs, rhs))
|
||||||
auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>();
|
return folded;
|
||||||
auto rhsNeg = rhs.dyn_cast<SDBMNegExpr>();
|
assert(!(lhs.isa<SDBMNegExpr>() && rhs.isa<SDBMNegExpr>()) &&
|
||||||
assert(!(lhsNeg && rhsNeg) && "a sum of negated expressions is a negation of "
|
"a sum of negated expressions is a negation of a sum of variables and "
|
||||||
"a sum of variables and not a correct SDBM");
|
"not a correct SDBM");
|
||||||
if (lhsNeg)
|
|
||||||
return rhs - lhsNeg.getVar();
|
// Fold (x - y) + (y - x) = 0.
|
||||||
if (rhsNeg)
|
auto lhsDiff = lhs.dyn_cast<SDBMDiffExpr>();
|
||||||
return lhs - rhsNeg.getVar();
|
auto rhsDiff = rhs.dyn_cast<SDBMDiffExpr>();
|
||||||
|
if (lhsDiff && rhsDiff) {
|
||||||
|
if (lhsDiff.getLHS() == rhsDiff.getRHS() &&
|
||||||
|
lhsDiff.getRHS() == rhsDiff.getLHS())
|
||||||
|
return SDBMConstantExpr::get(lhs.getDialect(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
// If LHS is a constant and RHS is not, swap the order to get into a supported
|
// If LHS is a constant and RHS is not, swap the order to get into a supported
|
||||||
// sum case. From now on, RHS must be a constant.
|
// sum case. From now on, RHS must be a constant.
|
||||||
|
@ -562,26 +676,11 @@ SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs) {
|
||||||
}
|
}
|
||||||
assert(rhsConstant && "at least one operand must be a constant");
|
assert(rhsConstant && "at least one operand must be a constant");
|
||||||
|
|
||||||
// If LHS is another sum, first compute the sum of its variable
|
// Constant-fold if LHS is also a constant.
|
||||||
// part with the other argument and then add the constant part to enable
|
|
||||||
// constant folding (the variable part may, e.g., be a negation that requires
|
|
||||||
// to enter this function again).
|
|
||||||
auto lhsSum = lhs.dyn_cast<SDBMSumExpr>();
|
|
||||||
if (lhsSum)
|
|
||||||
return lhsSum.getLHS() +
|
|
||||||
(lhsSum.getRHS().getValue() + rhsConstant.getValue());
|
|
||||||
|
|
||||||
// Constant-fold if LHS is a constant.
|
|
||||||
if (lhsConstant)
|
if (lhsConstant)
|
||||||
return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() +
|
return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() +
|
||||||
rhsConstant.getValue());
|
rhsConstant.getValue());
|
||||||
|
return addConstant(lhs.cast<SDBMVaryingExpr>(), rhsConstant.getValue());
|
||||||
// Fold x + 0 == x.
|
|
||||||
if (rhsConstant.getValue() == 0)
|
|
||||||
return lhs;
|
|
||||||
|
|
||||||
return SDBMSumExpr::get(lhs.cast<SDBMVaryingExpr>(),
|
|
||||||
rhs.cast<SDBMConstantExpr>());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs) {
|
SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs) {
|
||||||
|
@ -608,25 +707,7 @@ SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs) {
|
||||||
if (lhsConstant)
|
if (lhsConstant)
|
||||||
return -rhs + lhsConstant;
|
return -rhs + lhsConstant;
|
||||||
|
|
||||||
// Hoist constant factors outside the difference if any of sides is a sum:
|
return buildDiffExpr(lhs.cast<SDBMDirectExpr>(), (-rhs).cast<SDBMNegExpr>());
|
||||||
// (x + A) - (y - B) == x - y + (A - B).
|
|
||||||
// If either LHS or RHS is a sum, collect the constant values separately and
|
|
||||||
// update LHS and RHS to point to the variable part of the sum.
|
|
||||||
auto lhsSum = lhs.dyn_cast<SDBMSumExpr>();
|
|
||||||
auto rhsSum = rhs.dyn_cast<SDBMSumExpr>();
|
|
||||||
int64_t value = 0;
|
|
||||||
if (lhsSum) {
|
|
||||||
value += lhsSum.getRHS().getValue();
|
|
||||||
lhs = lhsSum.getLHS();
|
|
||||||
}
|
|
||||||
if (rhsSum) {
|
|
||||||
value -= rhsSum.getRHS().getValue();
|
|
||||||
rhs = rhsSum.getLHS();
|
|
||||||
}
|
|
||||||
|
|
||||||
// This calls into operator+ for futher simplification in case value == 0.
|
|
||||||
return SDBMDiffExpr::get(lhs.cast<SDBMTermExpr>(), rhs.cast<SDBMTermExpr>()) +
|
|
||||||
value;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
SDBMExpr stripe(SDBMExpr expr, SDBMExpr factor) {
|
SDBMExpr stripe(SDBMExpr expr, SDBMExpr factor) {
|
||||||
|
|
|
@ -43,7 +43,7 @@ struct SDBMExprStorage : public StorageUniquer::BaseStorage {
|
||||||
|
|
||||||
// Storage class for SDBM sum and stripe expressions.
|
// Storage class for SDBM sum and stripe expressions.
|
||||||
struct SDBMBinaryExprStorage : public SDBMExprStorage {
|
struct SDBMBinaryExprStorage : public SDBMExprStorage {
|
||||||
using KeyTy = std::pair<SDBMVaryingExpr, SDBMConstantExpr>;
|
using KeyTy = std::pair<SDBMDirectExpr, SDBMConstantExpr>;
|
||||||
|
|
||||||
bool operator==(const KeyTy &key) const {
|
bool operator==(const KeyTy &key) const {
|
||||||
return std::get<0>(key) == lhs && std::get<1>(key) == rhs;
|
return std::get<0>(key) == lhs && std::get<1>(key) == rhs;
|
||||||
|
@ -58,13 +58,13 @@ struct SDBMBinaryExprStorage : public SDBMExprStorage {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
SDBMVaryingExpr lhs;
|
SDBMDirectExpr lhs;
|
||||||
SDBMConstantExpr rhs;
|
SDBMConstantExpr rhs;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Storage class for SDBM difference expressions.
|
// Storage class for SDBM difference expressions.
|
||||||
struct SDBMDiffExprStorage : public SDBMExprStorage {
|
struct SDBMDiffExprStorage : public SDBMExprStorage {
|
||||||
using KeyTy = std::pair<SDBMTermExpr, SDBMTermExpr>;
|
using KeyTy = std::pair<SDBMDirectExpr, SDBMTermExpr>;
|
||||||
|
|
||||||
bool operator==(const KeyTy &key) const {
|
bool operator==(const KeyTy &key) const {
|
||||||
return std::get<0>(key) == lhs && std::get<1>(key) == rhs;
|
return std::get<0>(key) == lhs && std::get<1>(key) == rhs;
|
||||||
|
@ -79,7 +79,7 @@ struct SDBMDiffExprStorage : public SDBMExprStorage {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
SDBMTermExpr lhs;
|
SDBMDirectExpr lhs;
|
||||||
SDBMTermExpr rhs;
|
SDBMTermExpr rhs;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -117,19 +117,19 @@ struct SDBMTermExprStorage : public SDBMExprStorage {
|
||||||
|
|
||||||
// Storage class for SDBM negation expressions.
|
// Storage class for SDBM negation expressions.
|
||||||
struct SDBMNegExprStorage : public SDBMExprStorage {
|
struct SDBMNegExprStorage : public SDBMExprStorage {
|
||||||
using KeyTy = SDBMTermExpr;
|
using KeyTy = SDBMDirectExpr;
|
||||||
|
|
||||||
bool operator==(const KeyTy &key) const { return key == dim; }
|
bool operator==(const KeyTy &key) const { return key == expr; }
|
||||||
|
|
||||||
static SDBMNegExprStorage *
|
static SDBMNegExprStorage *
|
||||||
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
|
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
|
||||||
auto *result = allocator.allocate<SDBMNegExprStorage>();
|
auto *result = allocator.allocate<SDBMNegExprStorage>();
|
||||||
result->dim = key;
|
result->expr = key;
|
||||||
result->dialect = key.getDialect();
|
result->dialect = key.getDialect();
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
SDBMTermExpr dim;
|
SDBMDirectExpr expr;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // end namespace detail
|
} // end namespace detail
|
||||||
|
|
|
@ -161,13 +161,13 @@ TEST_FUNC(SDBM_StripeTightening) {
|
||||||
|
|
||||||
SmallVector<SDBMExpr, 4> eqs, ineqs;
|
SmallVector<SDBMExpr, 4> eqs, ineqs;
|
||||||
sdbm.getSDBMExpressions(dialect(), ineqs, eqs);
|
sdbm.getSDBMExpressions(dialect(), ineqs, eqs);
|
||||||
// CHECK: s0 # 3 - d0 + -2
|
// CHECK: s0 # 3 + -2 - d0
|
||||||
// CHECK-EMPTY:
|
// CHECK-EMPTY:
|
||||||
for (auto ineq : ineqs)
|
for (auto ineq : ineqs)
|
||||||
ineq.print(llvm::outs() << '\n');
|
ineq.print(llvm::outs() << '\n');
|
||||||
llvm::outs() << "\n";
|
llvm::outs() << "\n";
|
||||||
|
|
||||||
// CHECK-DAG: d1 - d0 + -42
|
// CHECK-DAG: d1 + -42 - d0
|
||||||
// CHECK-DAG: d0 - s0 # 3 # 5
|
// CHECK-DAG: d0 - s0 # 3 # 5
|
||||||
for (auto eq : eqs)
|
for (auto eq : eqs)
|
||||||
eq.print(llvm::outs() << '\n');
|
eq.print(llvm::outs() << '\n');
|
||||||
|
|
|
@ -76,6 +76,28 @@ TEST(SDBMOperators, AddFolding) {
|
||||||
|
|
||||||
auto inverted = SDBMNegExpr::get(SDBMDimExpr::get(dialect(), 1)) + dim(0);
|
auto inverted = SDBMNegExpr::get(SDBMDimExpr::get(dialect(), 1)) + dim(0);
|
||||||
EXPECT_EQ(inverted, expr);
|
EXPECT_EQ(inverted, expr);
|
||||||
|
|
||||||
|
// Check that opposite values cancel each other, and that we elide the zero
|
||||||
|
// constant.
|
||||||
|
expr = dim(0) + 42;
|
||||||
|
auto onlyDim = expr - 42;
|
||||||
|
EXPECT_EQ(onlyDim, dim(0));
|
||||||
|
|
||||||
|
// Check that we can sink a constant under a negation.
|
||||||
|
expr = -(dim(0) + 2);
|
||||||
|
auto negatedSum = (expr + 10).dyn_cast<SDBMNegExpr>();
|
||||||
|
ASSERT_TRUE(negatedSum);
|
||||||
|
auto sum = negatedSum.getVar().dyn_cast<SDBMSumExpr>();
|
||||||
|
ASSERT_TRUE(sum);
|
||||||
|
EXPECT_EQ(sum.getRHS().getValue(), -8);
|
||||||
|
|
||||||
|
// Sum with zero is the same as the original expression.
|
||||||
|
EXPECT_EQ(dim(0) + 0, dim(0));
|
||||||
|
|
||||||
|
// Sum of opposite differences is zero.
|
||||||
|
auto diffOfDiffs =
|
||||||
|
((dim(0) - dim(1)) + (dim(1) - dim(0))).dyn_cast<SDBMConstantExpr>();
|
||||||
|
EXPECT_EQ(diffOfDiffs.getValue(), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SDBMOperators, Diff) {
|
TEST(SDBMOperators, Diff) {
|
||||||
|
@ -101,6 +123,43 @@ TEST(SDBMOperators, DiffFolding) {
|
||||||
constantExpr = zero.dyn_cast<SDBMConstantExpr>();
|
constantExpr = zero.dyn_cast<SDBMConstantExpr>();
|
||||||
ASSERT_TRUE(constantExpr);
|
ASSERT_TRUE(constantExpr);
|
||||||
EXPECT_EQ(constantExpr.getValue(), 0);
|
EXPECT_EQ(constantExpr.getValue(), 0);
|
||||||
|
|
||||||
|
// Check that the constant terms in difference-of-sums are folded.
|
||||||
|
// (d0 - 3) - (d1 - 5) = (d0 + 2) - d1
|
||||||
|
auto diffOfSums = ((dim(0) - 3) - (dim(1) - 5)).dyn_cast<SDBMDiffExpr>();
|
||||||
|
ASSERT_TRUE(diffOfSums);
|
||||||
|
auto lhs = diffOfSums.getLHS().dyn_cast<SDBMSumExpr>();
|
||||||
|
ASSERT_TRUE(lhs);
|
||||||
|
EXPECT_EQ(lhs.getLHS(), dim(0));
|
||||||
|
EXPECT_EQ(lhs.getRHS().getValue(), 2);
|
||||||
|
EXPECT_EQ(diffOfSums.getRHS(), dim(1));
|
||||||
|
|
||||||
|
// Check that identical dimensions with opposite signs cancel each other.
|
||||||
|
auto cstOnly = ((dim(0) + 42) - dim(0)).dyn_cast<SDBMConstantExpr>();
|
||||||
|
ASSERT_TRUE(cstOnly);
|
||||||
|
EXPECT_EQ(cstOnly.getValue(), 42);
|
||||||
|
|
||||||
|
// Check that identical terms in sum of diffs cancel out.
|
||||||
|
auto dimOnly = (-dim(0) + (dim(0) - dim(1)));
|
||||||
|
EXPECT_EQ(dimOnly, -dim(1));
|
||||||
|
dimOnly = (dim(0) - dim(1)) + (-dim(0));
|
||||||
|
EXPECT_EQ(dimOnly, -dim(1));
|
||||||
|
dimOnly = (dim(0) - dim(1)) + dim(1);
|
||||||
|
EXPECT_EQ(dimOnly, dim(0));
|
||||||
|
dimOnly = dim(0) + (dim(1) - dim(0));
|
||||||
|
EXPECT_EQ(dimOnly, dim(1));
|
||||||
|
|
||||||
|
// Top-level zero constant is fine.
|
||||||
|
cstOnly = (-symb(1) + symb(1)).dyn_cast<SDBMConstantExpr>();
|
||||||
|
ASSERT_TRUE(cstOnly);
|
||||||
|
EXPECT_EQ(cstOnly.getValue(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SDBMOperators, Negate) {
|
||||||
|
auto sum = dim(0) + 3;
|
||||||
|
auto negated = (-sum).dyn_cast<SDBMNegExpr>();
|
||||||
|
ASSERT_TRUE(negated);
|
||||||
|
EXPECT_EQ(negated.getVar(), sum);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SDBMOperators, Stripe) {
|
TEST(SDBMOperators, Stripe) {
|
||||||
|
@ -324,11 +383,11 @@ TEST(SDBMExpr, AffineRoundTrip) {
|
||||||
|
|
||||||
// Check that (s0 # 2 # 5 - s0 # 2) + 2 can be converted as an example of a
|
// Check that (s0 # 2 # 5 - s0 # 2) + 2 can be converted as an example of a
|
||||||
// deeper expression tree.
|
// deeper expression tree.
|
||||||
auto diff = SDBMDiffExpr::get(outerStripe, stripe);
|
auto sum = SDBMSumExpr::get(outerStripe, cst2);
|
||||||
auto sum = SDBMSumExpr::get(diff, cst2);
|
auto diff = SDBMDiffExpr::get(sum, stripe);
|
||||||
roundtripped = SDBMExpr::tryConvertAffineExpr(sum.getAsAffineExpr());
|
roundtripped = SDBMExpr::tryConvertAffineExpr(diff.getAsAffineExpr());
|
||||||
ASSERT_TRUE(roundtripped.hasValue());
|
ASSERT_TRUE(roundtripped.hasValue());
|
||||||
EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(sum));
|
EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(diff));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SDBMExpr, MatchStripeMulPattern) {
|
TEST(SDBMExpr, MatchStripeMulPattern) {
|
||||||
|
|
Loading…
Reference in New Issue