Overload arithmetic operators for SDBM expressions

Provide an "unsafe" version of the overloaded arithmetic operators for SDBM
    expressions.  These operators expect the operands to be of the right SDBM
    expression subtype and assert if they are not.  They also perform simple
    folding operations as well as some semantically correct operations that
    construct an SDBM expression of a different subtype, e.g., a difference
    expression if the RHS of an operator+ is a negated variable.  These operators
    are scoped in a namespace to allow for a future "safe" version of the operators
    that propagates null expressions to denote the error state when expressions
    have wrong subtypes.

--

PiperOrigin-RevId: 248704153
This commit is contained in:
Alex Zinenko 2019-05-17 05:40:51 -07:00 committed by Mehdi Amini
parent f06ab26acf
commit 69ef8642df
4 changed files with 239 additions and 88 deletions

View File

@ -43,6 +43,10 @@ struct SDBMConstantExprStorage;
struct SDBMNegExprStorage;
} // namespace detail
class SDBMConstantExpr;
class SDBMDimExpr;
class SDBMSymbolExpr;
/// Striped Difference-Bounded Matrix (SDBM) expression is a base left-hand side
/// expression for the SDBM framework. SDBM expressions are a subset of affine
/// expressions supporting low-complexity algorithms for the operations used in
@ -400,6 +404,40 @@ protected:
}
};
/// Overloaded arithmetic operators for SDBM expressions asserting that their
/// arguments have the proper SDBM expression subtype. Perform canonicalization
/// and constant folding on these expressions.
namespace ops_assertions {
/// Add two SDBM expressions. At least one of the expressions must be a
/// constant or a negation, but both expressions cannot be negations
/// simultaneously.
SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs);
inline SDBMExpr operator+(SDBMExpr lhs, int64_t rhs) {
return lhs + SDBMConstantExpr::get(lhs.getContext(), rhs);
}
inline SDBMExpr operator+(int64_t lhs, SDBMExpr rhs) {
return SDBMConstantExpr::get(rhs.getContext(), lhs) + rhs;
}
/// Subtract an SDBM expression from another SDBM expression. Both expressions
/// must not be difference expressions.
SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs);
inline SDBMExpr operator-(SDBMExpr lhs, int64_t rhs) {
return lhs - SDBMConstantExpr::get(lhs.getContext(), rhs);
}
inline SDBMExpr operator-(int64_t lhs, SDBMExpr rhs) {
return SDBMConstantExpr::get(rhs.getContext(), lhs) - rhs;
}
/// Construct a stripe expression from a positive expression and a positive
/// constant stripe factor.
SDBMExpr stripe(SDBMExpr expr, SDBMExpr factor);
inline SDBMExpr stripe(SDBMExpr expr, int64_t factor) {
return stripe(expr, SDBMConstantExpr::get(expr.getContext(), factor));
}
} // namespace ops_assertions
} // end namespace mlir
namespace llvm {

View File

@ -353,16 +353,16 @@ void SDBM::convertDBMElement(MLIRContext *context, unsigned row, unsigned col,
SDBMPositiveExpr rowExpr, SDBMPositiveExpr colExpr,
SmallVectorImpl<SDBMExpr> &inequalities,
SmallVectorImpl<SDBMExpr> &equalities) {
using ops_assertions::operator+;
using ops_assertions::operator-;
auto diffIJValue = at(col, row);
auto diffJIValue = at(row, col);
auto diffIJValueExpr =
SDBMConstantExpr::get(context, -diffIJValue.getValue());
auto diffIJExpr = SDBMDiffExpr::get(rowExpr, colExpr);
// If symmetric entries are equal, so are the corresponding expressions.
if (diffIJValue.isFinite() &&
diffIJValue.getValue() == -diffJIValue.getValue()) {
equalities.push_back(SDBMSumExpr::get(diffIJExpr, diffIJValueExpr));
equalities.push_back(rowExpr - colExpr - diffIJValue.getValue());
return;
}
@ -399,15 +399,12 @@ void SDBM::convertDBMElement(MLIRContext *context, unsigned row, unsigned col,
// Check row - col.
if (diffIJValue.isFinite() &&
!canElide(row, col, rowExpr, colExpr, diffIJValue.getValue())) {
inequalities.push_back(SDBMSumExpr::get(diffIJExpr, diffIJValueExpr));
inequalities.push_back(rowExpr - colExpr - diffIJValue.getValue());
}
// Check col - row.
if (diffJIValue.isFinite() &&
!canElide(col, row, colExpr, rowExpr, diffJIValue.getValue())) {
auto diffJIExpr = SDBMDiffExpr::get(colExpr, rowExpr);
auto diffJIValueExpr =
SDBMConstantExpr::get(context, -diffJIValue.getValue());
inequalities.push_back(SDBMSumExpr::get(diffJIExpr, diffJIValueExpr));
inequalities.push_back(colExpr - rowExpr - diffJIValue.getValue());
}
}
@ -421,17 +418,18 @@ void SDBM::convertDBMDiagonalElement(MLIRContext *context, unsigned pos,
SmallVectorImpl<SDBMExpr> &inequalities) {
auto selfDifference = at(pos, pos);
if (selfDifference.isFinite() && selfDifference < 0) {
auto selfDifferenceExpr = SDBMDiffExpr::get(expr, expr);
auto selfDifferenceValueExpr =
SDBMConstantExpr::get(context, -selfDifference.getValue());
inequalities.push_back(
SDBMSumExpr::get(selfDifferenceExpr, selfDifferenceValueExpr));
inequalities.push_back(selfDifferenceValueExpr);
}
}
void SDBM::getSDBMExpressions(MLIRContext *context,
SmallVectorImpl<SDBMExpr> &inequalities,
SmallVectorImpl<SDBMExpr> &equalities) {
using ops_assertions::operator-;
using ops_assertions::operator+;
// Helper function that creates an SDBMInputExpr given the linearized position
// of variable in the DBM.
auto getInput = [context, this](unsigned matrixPos) -> SDBMInputExpr {
@ -457,19 +455,14 @@ void SDBM::getSDBMExpressions(MLIRContext *context,
// each variable. Transform them into inequalities if they are finite.
auto upperBound = at(0, 1 + i);
auto lowerBound = at(1 + i, 0);
auto upperBoundExpr =
SDBMConstantExpr::get(context, -upperBound.getValue());
auto inputExpr = getInput(i);
if (upperBound.isFinite() &&
upperBound.getValue() == -lowerBound.getValue()) {
equalities.push_back(SDBMSumExpr::get(inputExpr, upperBoundExpr));
equalities.push_back(inputExpr - upperBound.getValue());
} else if (upperBound.isFinite()) {
inequalities.push_back(SDBMSumExpr::get(inputExpr, upperBoundExpr));
inequalities.push_back(inputExpr - upperBound.getValue());
} else if (lowerBound.isFinite()) {
auto lowerBoundExpr =
SDBMConstantExpr::get(context, -lowerBound.getValue());
inequalities.push_back(
SDBMSumExpr::get(SDBMNegExpr::get(inputExpr), lowerBoundExpr));
inequalities.push_back(-inputExpr - lowerBound.getValue());
}
// Introduce trivially false inequalities if required by diagonal elements.
@ -488,8 +481,7 @@ void SDBM::getSDBMExpressions(MLIRContext *context,
for (const auto &stripePair : stripeToPoint) {
unsigned position = stripePair.first;
if (position < 1 + numTrueVariables) {
equalities.push_back(SDBMDiffExpr::get(
getInput(position - 1), stripePair.second.cast<SDBMStripeExpr>()));
equalities.push_back(getInput(position - 1) - stripePair.second);
}
}

View File

@ -441,3 +441,109 @@ int64_t SDBMConstantExpr::getValue() const {
SDBMPositiveExpr SDBMNegExpr::getVar() const {
return static_cast<ImplType *>(impl)->dim;
}
namespace mlir {
namespace ops_assertions {
SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs) {
// If one of the operands is a negation, take a difference rather than a sum.
auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>();
auto rhsNeg = rhs.dyn_cast<SDBMNegExpr>();
assert(!(lhsNeg && rhsNeg) && "a sum of negated expressions is a negation of "
"a sum of variables and not a correct SDBM");
if (lhsNeg)
return rhs - lhsNeg.getVar();
if (rhsNeg)
return lhs - rhsNeg.getVar();
// If LHS is a constant and RHS is not, swap the order to get into a supported
// sum case. From now on, RHS must be a constant.
auto lhsConstant = lhs.dyn_cast<SDBMConstantExpr>();
auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
if (!rhsConstant && lhsConstant) {
std::swap(lhs, rhs);
std::swap(lhsConstant, rhsConstant);
}
assert(rhsConstant && "at least one operand must be a constant");
// If LHS is another sum, first compute the sum of its variable
// part with the other argument and then add the constant part to enable
// constant folding (the variable part may, e.g., be a negation that requires
// to enter this function again).
auto lhsSum = lhs.dyn_cast<SDBMSumExpr>();
if (lhsSum)
return lhsSum.getLHS() +
(lhsSum.getRHS().getValue() + rhsConstant.getValue());
// Constant-fold if LHS is a constant.
if (lhsConstant)
return SDBMConstantExpr::get(lhs.getContext(), lhsConstant.getValue() +
rhsConstant.getValue());
// Fold x + 0 == x.
if (rhsConstant.getValue() == 0)
return lhs;
return SDBMSumExpr::get(lhs.cast<SDBMVaryingExpr>(),
rhs.cast<SDBMConstantExpr>());
}
SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs) {
// Fold x - x == 0.
if (lhs == rhs)
return SDBMConstantExpr::get(lhs.getContext(), 0);
// LHS and RHS may be constants.
auto lhsConstant = lhs.dyn_cast<SDBMConstantExpr>();
auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
// Constant fold if both LHS and RHS are constants.
if (lhsConstant && rhsConstant)
return SDBMConstantExpr::get(lhs.getContext(), lhsConstant.getValue() -
rhsConstant.getValue());
// Replace a difference with a sum with a negated value if one of LHS and RHS
// is a constant:
// x - C == x + (-C);
// C - x == -x + C.
// This calls into operator+ for further simplification.
if (rhsConstant)
return lhs + (-rhsConstant);
if (lhsConstant)
return -rhs + lhsConstant;
// Hoist constant factors outside the difference if any of sides is a sum:
// (x + A) - (y - B) == x - y + (A - B).
// If either LHS or RHS is a sum, collect the constant values separately and
// update LHS and RHS to point to the variable part of the sum.
auto lhsSum = lhs.dyn_cast<SDBMSumExpr>();
auto rhsSum = rhs.dyn_cast<SDBMSumExpr>();
int64_t value = 0;
if (lhsSum) {
value += lhsSum.getRHS().getValue();
lhs = lhsSum.getLHS();
}
if (rhsSum) {
value -= rhsSum.getRHS().getValue();
rhs = rhsSum.getLHS();
}
// This calls into operator+ for futher simplification in case value == 0.
return SDBMDiffExpr::get(lhs.cast<SDBMPositiveExpr>(),
rhs.cast<SDBMPositiveExpr>()) +
value;
}
SDBMExpr stripe(SDBMExpr expr, SDBMExpr factor) {
auto constantFactor = factor.cast<SDBMConstantExpr>();
assert(constantFactor.getValue() > 0 && "non-positive stripe");
// Fold x # 1 = x.
if (constantFactor.getValue() == 1)
return expr;
return SDBMStripeExpr::get(expr.cast<SDBMPositiveExpr>(), constantFactor);
}
} // namespace ops_assertions
} // namespace mlir

View File

@ -30,8 +30,76 @@ static MLIRContext *ctx() {
return &context;
}
static SDBMExpr dim(unsigned pos) { return SDBMDimExpr::get(ctx(), pos); }
static SDBMExpr symb(unsigned pos) { return SDBMSymbolExpr::get(ctx(), pos); }
namespace {
using namespace mlir::ops_assertions;
TEST(SDBMOperators, Add) {
auto expr = dim(0) + 42;
auto sumExpr = expr.dyn_cast<SDBMSumExpr>();
ASSERT_TRUE(sumExpr);
EXPECT_EQ(sumExpr.getLHS(), dim(0));
EXPECT_EQ(sumExpr.getRHS().getValue(), 42);
}
TEST(SDBMOperators, AddFolding) {
auto constant = SDBMConstantExpr::get(ctx(), 2) + 42;
auto constantExpr = constant.dyn_cast<SDBMConstantExpr>();
ASSERT_TRUE(constantExpr);
EXPECT_EQ(constantExpr.getValue(), 44);
auto expr = (dim(0) + 10) + 32;
auto sumExpr = expr.dyn_cast<SDBMSumExpr>();
ASSERT_TRUE(sumExpr);
EXPECT_EQ(sumExpr.getRHS().getValue(), 42);
expr = dim(0) + SDBMNegExpr::get(SDBMDimExpr::get(ctx(), 1));
auto diffExpr = expr.dyn_cast<SDBMDiffExpr>();
ASSERT_TRUE(diffExpr);
EXPECT_EQ(diffExpr.getLHS(), dim(0));
EXPECT_EQ(diffExpr.getRHS(), dim(1));
auto inverted = SDBMNegExpr::get(SDBMDimExpr::get(ctx(), 1)) + dim(0);
EXPECT_EQ(inverted, expr);
}
TEST(SDBMOperators, Diff) {
auto expr = dim(0) - dim(1);
auto diffExpr = expr.dyn_cast<SDBMDiffExpr>();
ASSERT_TRUE(diffExpr);
EXPECT_EQ(diffExpr.getLHS(), dim(0));
EXPECT_EQ(diffExpr.getRHS(), dim(1));
}
TEST(SDBMOperators, DiffFolding) {
auto constant = SDBMConstantExpr::get(ctx(), 10) - 3;
auto constantExpr = constant.dyn_cast<SDBMConstantExpr>();
ASSERT_TRUE(constantExpr);
EXPECT_EQ(constantExpr.getValue(), 7);
auto expr = dim(0) - 3;
auto sumExpr = expr.dyn_cast<SDBMSumExpr>();
ASSERT_TRUE(sumExpr);
EXPECT_EQ(sumExpr.getRHS().getValue(), -3);
auto zero = dim(0) - dim(0);
constantExpr = zero.dyn_cast<SDBMConstantExpr>();
ASSERT_TRUE(constantExpr);
EXPECT_EQ(constantExpr.getValue(), 0);
}
TEST(SDBMOperators, Stripe) {
auto expr = stripe(dim(0), 3);
auto stripeExpr = expr.dyn_cast<SDBMStripeExpr>();
ASSERT_TRUE(stripeExpr);
EXPECT_EQ(stripeExpr.getVar(), dim(0));
EXPECT_EQ(stripeExpr.getStripeFactor().getValue(), 3);
}
TEST(SDBM, SingleConstraint) {
// Build an SDBM defined by
//
@ -40,11 +108,7 @@ TEST(SDBM, SingleConstraint) {
// cst d0
// cst inf 3
// d0 inf inf
auto cst3 = SDBMConstantExpr::get(ctx(), -3);
auto dim0 = SDBMDimExpr::get(ctx(), 0);
auto expr = SDBMSumExpr::get(dim0, cst3);
auto sdbm = SDBM::get(expr, llvm::None);
auto sdbm = SDBM::get(dim(0) - 3, llvm::None);
EXPECT_EQ(sdbm(0, 1), 3);
}
@ -58,12 +122,7 @@ TEST(SDBM, Equality) {
// cst inf inf inf
// d0 inf inf -3
// d1 inf 3 inf
auto cst3 = SDBMConstantExpr::get(ctx(), -3);
auto dim0 = SDBMDimExpr::get(ctx(), 0);
auto dim1 = SDBMDimExpr::get(ctx(), 1);
auto expr = SDBMSumExpr::get(SDBMDiffExpr::get(dim0, dim1), cst3);
auto sdbm = SDBM::get(llvm::None, expr);
auto sdbm = SDBM::get(llvm::None, dim(0) - dim(1) - 3);
EXPECT_EQ(sdbm(1, 2), -3);
EXPECT_EQ(sdbm(2, 1), 3);
}
@ -78,14 +137,8 @@ TEST(SDBM, TrivialSimplification) {
//
// cst d0
// cst inf 3
// d0 inf inf
auto cst5 = SDBMConstantExpr::get(ctx(), -5);
auto cst3 = SDBMConstantExpr::get(ctx(), -3);
auto dim0 = SDBMDimExpr::get(ctx(), 0);
auto expr5 = SDBMSumExpr::get(dim0, cst5);
auto expr3 = SDBMSumExpr::get(dim0, cst3);
auto sdbm = SDBM::get({expr3, expr5}, llvm::None);
// d0 inf inf;
auto sdbm = SDBM::get({dim(0) - 3, dim(0) - 5}, llvm::None);
EXPECT_EQ(sdbm(0, 1), 3);
}
@ -99,14 +152,7 @@ TEST(SDBM, StripeInducedIneqs) {
// cst inf inf inf
// d0 inf inf 2
// d1 inf 0 0
auto dim0 = SDBMDimExpr::get(ctx(), 0);
auto dim1 = SDBMDimExpr::get(ctx(), 1);
auto cst3 = SDBMConstantExpr::get(ctx(), 3);
auto stripe = SDBMStripeExpr::get(dim0, cst3);
auto expr = SDBMDiffExpr::get(dim1, stripe);
auto sdbm = SDBM::get(llvm::None, expr);
auto sdbm = SDBM::get(llvm::None, dim(1) - stripe(dim(0), 3));
EXPECT_EQ(sdbm(1, 2), 2);
EXPECT_EQ(sdbm(2, 1), 0);
}
@ -123,11 +169,7 @@ TEST(SDBM, StripeTemporaries) {
// cst inf inf 0
// d0 inf inf 2
// t0 inf 0 inf
auto dim0 = SDBMDimExpr::get(ctx(), 0);
auto cst3 = SDBMConstantExpr::get(ctx(), 3);
auto stripe = SDBMStripeExpr::get(dim0, cst3);
auto sdbm = SDBM::get(stripe, llvm::None);
auto sdbm = SDBM::get(stripe(dim(0), 3), llvm::None);
EXPECT_EQ(sdbm(0, 2), 0);
EXPECT_EQ(sdbm(1, 2), 2);
EXPECT_EQ(sdbm(2, 1), 0);
@ -144,19 +186,12 @@ TEST(SDBM, RoundTripEqs) {
// different due to simplification or equivalent substitutions (e.g., the
// second equality may become d0 - d1 + 42 = 0). However, there should not
// be any further simplification after the second round-trip,
auto cst3 = SDBMConstantExpr::get(ctx(), 3);
auto cst5 = SDBMConstantExpr::get(ctx(), 5);
auto dim0 = SDBMSymbolExpr::get(ctx(), 0);
auto stripe = SDBMStripeExpr::get(SDBMStripeExpr::get(dim0, cst3), cst5);
auto foo = SDBMDiffExpr::get(stripe, SDBMDimExpr::get(ctx(), 0));
auto bar =
SDBMSumExpr::get(SDBMDiffExpr::get(stripe, SDBMDimExpr::get(ctx(), 1)),
SDBMConstantExpr::get(ctx(), 42));
// Build the SDBM from a pair of equalities and extract back the lists of
// inequalities and equalities. Check that all equalities are properly
// detected and none of them decayed into inequalities.
auto sdbm = SDBM::get(llvm::None, {foo, bar});
auto s = stripe(stripe(symb(0), 3), 5);
auto sdbm = SDBM::get(llvm::None, {s - dim(0), s - dim(1) + 42});
SmallVector<SDBMExpr, 4> eqs, ineqs;
sdbm.getSDBMExpressions(ctx(), ineqs, eqs);
ASSERT_TRUE(ineqs.empty());
@ -185,20 +220,9 @@ TEST(SDBM, StripeTightening) {
// equality (d0 - s0 # 3 <= 5 - 1 = 4). Check that the conversion from SDBM
// back to the lists of constraints conserves both the stripe equality and the
// tighter inequality.
auto cst3 = SDBMConstantExpr::get(ctx(), 3);
auto cst5 = SDBMConstantExpr::get(ctx(), 5);
auto dim0 = SDBMSymbolExpr::get(ctx(), 0);
auto stripe = SDBMStripeExpr::get(SDBMStripeExpr::get(dim0, cst3), cst5);
auto foo = SDBMDiffExpr::get(stripe, SDBMDimExpr::get(ctx(), 0));
auto bar = SDBMSumExpr::get(
SDBMDiffExpr::get(SDBMDimExpr::get(ctx(), 0), SDBMDimExpr::get(ctx(), 1)),
SDBMConstantExpr::get(ctx(), 42));
auto tight =
SDBMSumExpr::get(SDBMDiffExpr::get(SDBMDimExpr::get(ctx(), 0),
SDBMStripeExpr::get(dim0, cst3)),
SDBMConstantExpr::get(ctx(), -2));
auto sdbm = SDBM::get({tight}, {foo, bar});
auto s = stripe(stripe(symb(0), 3), 5);
auto tight = dim(0) - stripe(symb(0), 3) - 2;
auto sdbm = SDBM::get({tight}, {s - dim(0), s - dim(1) + 42});
SmallVector<SDBMExpr, 4> eqs, ineqs;
sdbm.getSDBMExpressions(ctx(), ineqs, eqs);
@ -224,17 +248,8 @@ TEST(SDBM, StripeTransitive) {
// d1 inf 2 inf inf inf
// d2 inf inf inf inf 6
// t0 inf 0 inf 0 inf
auto cst3 = SDBMConstantExpr::get(ctx(), 3);
auto cst7 = SDBMConstantExpr::get(ctx(), 7);
auto dim0 = SDBMDimExpr::get(ctx(), 0);
auto dim1 = SDBMDimExpr::get(ctx(), 1);
auto dim2 = SDBMDimExpr::get(ctx(), 2);
auto stripe1 = SDBMStripeExpr::get(dim1, cst3);
auto stripe2 = SDBMStripeExpr::get(dim2, cst7);
auto diff1 = SDBMDiffExpr::get(stripe1, dim0);
auto diff2 = SDBMDiffExpr::get(stripe2, dim0);
auto sdbm = SDBM::get(llvm::None, {diff1, diff2});
auto sdbm = SDBM::get(
llvm::None, {stripe(dim(1), 3) - dim(0), stripe(dim(2), 7) - dim(0)});
// Induced by d0 = d1 # 3.
EXPECT_EQ(sdbm(1, 2), 0);
EXPECT_EQ(sdbm(2, 1), 2);