Clean up memref dep check utilities; update FlatAffineConstraints API, add

simple utility methods.

- clean up some of the analysis utilities used by memref dep checking
- add additional asserts / comments at places in analysis utilities
- add additional simple methods to the FlatAffineConstraints API.

PiperOrigin-RevId: 220124523
This commit is contained in:
Uday Bondhugula 2018-11-05 10:12:16 -08:00 committed by jpienaar
parent 9a62178372
commit 4269a01863
4 changed files with 117 additions and 29 deletions

View File

@ -37,6 +37,7 @@ class MLIRContext;
class FlatAffineConstraints;
class MLValue;
class OperationStmt;
class Statement;
/// Simplify an affine expression through flattening and some amount of
/// simple analysis. This has complexity linear in the number of nodes in
@ -66,6 +67,11 @@ bool getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
llvm::SmallVectorImpl<int64_t> *flattenedExpr,
FlatAffineConstraints *cst = nullptr);
/// Adds constraints capturing the index set of the ML values in indices to
/// 'domain'.
bool addIndexSet(llvm::ArrayRef<const MLValue *> indices,
FlatAffineConstraints *domain);
/// Checks whether two accesses to the same memref access the same element.
/// Each access is specified using the MemRefAccess structure, which contains
/// the operation statement, indices and memref associated with the access.

View File

@ -456,6 +456,23 @@ public:
/// Clears this list of constraints and copies other into it.
void clearAndCopyFrom(const FlatAffineConstraints &other);
/// Returns the constant lower bound of the specified identifier (through a
/// scan through the constraints); returns None if the bound isn't trivially a
/// constant.
Optional<int64_t> getConstantLowerBound(unsigned pos);
/// Returns the constant upper bound of the specified identifier (through a
/// scan through the constraints); returns None if the bound isn't trivially a
/// constant.
Optional<int64_t> getConstantUpperBound(unsigned pos);
// Returns the lower and upper bounds of the specified dimensions as
// AffineMap's. Returns false for the unimplemented cases for the moment.
bool getDimensionBounds(unsigned pos, unsigned num,
SmallVectorImpl<AffineMap> *lbs,
SmallVectorImpl<AffineMap> *ubs,
MLIRContext *context);
// More expensive ones.
void removeDuplicates();

View File

@ -216,11 +216,13 @@ public:
void visitDimExpr(AffineDimExpr expr) {
operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
auto &eq = operandExprStack.back();
assert(expr.getPosition() < numDims && "Inconsistent number of dims");
eq[getDimStartIndex() + expr.getPosition()] = 1;
}
void visitSymbolExpr(AffineSymbolExpr expr) {
operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
auto &eq = operandExprStack.back();
assert(expr.getPosition() < numSymbols && "inconsistent number of symbols");
eq[getSymbolStartIndex() + expr.getPosition()] = 1;
}
void visitConstantExpr(AffineConstantExpr expr) {
@ -283,6 +285,8 @@ private:
bound[bound.size() - 1] = -(rhsConst - 1);
cst.addLowerBound(lhs, bound);
}
// Set the expression on stack to the local var introduced to capture the
// result of the division (floor or ceil).
std::fill(lhs.begin(), lhs.end(), 0);
lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
}
@ -420,29 +424,17 @@ void mlir::forwardSubstituteReachableOps(AffineValueMap *valueMap) {
// TODO(andydavis) Handle non-unit Step by adding local variable
// (iv - lb % step = 0 introducing a method in FlatAffineConstraints
// setExprStride(ArrayRef<int64_t> expr, int64_t stride)
static bool addForStmtBounds(unsigned numDims,
ArrayRef<const MLValue *> forStmts,
FlatAffineConstraints *domain) {
assert(forStmts.size() >= numDims);
unsigned numIds = forStmts.size();
// Add InEqualties for loop bounds.
SmallVector<int64_t, 4> ineq;
ineq.resize(numIds + 1);
for (unsigned i = 0; i < numDims; ++i) {
const ForStmt *forStmt = dyn_cast<ForStmt>(forStmts[i]);
bool mlir::addIndexSet(ArrayRef<const MLValue *> indices,
FlatAffineConstraints *domain) {
unsigned numIds = indices.size();
for (unsigned i = 0; i < numIds; ++i) {
const ForStmt *forStmt = dyn_cast<ForStmt>(indices[i]);
if (!forStmt || !forStmt->hasConstantBounds())
return false;
// Zero fill
std::fill(ineq.begin(), ineq.end(), 0);
// TODO(andydavis, bondhugula) Add methods for addUpper/LowerBound.
// Add inequality for lower bound.
ineq[i] = 1;
ineq[numIds] = -forStmt->getConstantLowerBound();
domain->addInequality(ineq);
domain->addConstantLowerBound(i, forStmt->getConstantLowerBound());
// Add inequality for upper bound.
ineq[i] = -1;
ineq[numIds] = forStmt->getConstantUpperBound();
domain->addInequality(ineq);
domain->addConstantUpperBound(i, forStmt->getConstantUpperBound());
}
return true;
}
@ -476,13 +468,13 @@ struct IterationDomainContext {
// TODO(andydavis) Capture the context of the symbols. For example, check
// if a symbol is the result of a constant operation, and set the symbol to
// that value in FlatAffineConstraints (using setIdToConstant).
static bool getIterationDomainContext(const OperationStmt *opStmt,
IterationDomainContext *ctx) {
bool getIterationDomainContext(const Statement *stmt,
IterationDomainContext *ctx) {
// Walk up tree storing parent statements in 'loops'.
// TODO(andydavis) Extend this to gather enclosing IfStmts and consider
// factoring it out into a utility function.
SmallVector<const ForStmt *, 4> loops;
const auto *currStmt = opStmt->getParentStmt();
const auto *currStmt = stmt->getParentStmt();
while (currStmt != nullptr) {
if (isa<IfStmt>(currStmt))
return false;
@ -510,7 +502,7 @@ static bool getIterationDomainContext(const OperationStmt *opStmt,
/*newNumReservedEqualities=*/0,
/*newNumReservedCols=*/numDims + numSymbols + 1, numDims,
numSymbols);
return addForStmtBounds(numDims, ctx->values, &ctx->domain);
return addIndexSet(ctx->values, &ctx->domain);
}
// Builds a map from MLValue to identifier position in a new merged identifier

View File

@ -578,7 +578,6 @@ void FlatAffineConstraints::addSymbolId(unsigned pos) {
/// Adds a dimensional identifier. The added column is initialized to
/// zero.
void FlatAffineConstraints::addId(IdKind kind, unsigned pos) {
assert(pos >= 0);
if (kind == IdKind::Dimension) {
assert(pos <= getNumDimIds());
} else if (kind == IdKind::Symbol) {
@ -645,7 +644,7 @@ void FlatAffineConstraints::addId(IdKind kind, unsigned pos) {
void FlatAffineConstraints::composeMap(AffineValueMap *vMap, unsigned pos) {
assert(vMap->getNumOperands() == getNumIds() && "inconsistent map");
assert(vMap->getNumDims() == getNumDimIds() && "inconsistent map");
assert(pos >= 0 && pos <= getNumIds() && "invalid position");
assert(pos <= getNumIds() && "invalid position");
AffineMap map = vMap->getAffineMap();
@ -1003,7 +1002,7 @@ void FlatAffineConstraints::addInequality(ArrayRef<int64_t> inEq) {
}
void FlatAffineConstraints::addConstantLowerBound(unsigned pos, int64_t lb) {
assert(pos >= 0 && pos < getNumCols());
assert(pos < getNumCols());
unsigned offset = inequalities.size();
inequalities.resize(inequalities.size() + numReservedCols);
std::fill(inequalities.begin() + offset,
@ -1013,7 +1012,7 @@ void FlatAffineConstraints::addConstantLowerBound(unsigned pos, int64_t lb) {
}
void FlatAffineConstraints::addConstantUpperBound(unsigned pos, int64_t ub) {
assert(pos >= 0 && pos < getNumCols());
assert(pos < getNumCols());
unsigned offset = inequalities.size();
inequalities.resize(inequalities.size() + numReservedCols);
std::fill(inequalities.begin() + offset,
@ -1095,6 +1094,80 @@ void FlatAffineConstraints::removeEquality(unsigned pos) {
equalities.resize(equalities.size() - numReservedCols);
}
bool FlatAffineConstraints::getDimensionBounds(unsigned pos, unsigned num,
SmallVectorImpl<AffineMap> *lbs,
SmallVectorImpl<AffineMap> *ubs,
MLIRContext *context) {
assert(pos + num < getNumCols());
// Only constant dim bounds for now.
projectOut(0, pos);
projectOut(pos + num, getNumIds() - num);
lbs->resize(num, AffineMap::Null());
ubs->resize(num, AffineMap::Null());
for (int i = static_cast<int>(num) - 1; i >= 0; i--) {
auto lb = getConstantLowerBound(i);
auto ub = getConstantUpperBound(i);
// TODO(mlir-team): handle arbitrary bounds.
if (!lb.hasValue() || !ub.hasValue())
return false;
(*lbs)[i] = AffineMap::getConstantMap(lb.getValue(), context);
(*ubs)[i] = AffineMap::getConstantMap(ub.getValue(), context);
projectOut(i, 1);
}
return true;
}
Optional<int64_t> FlatAffineConstraints::getConstantLowerBound(unsigned pos) {
assert(pos < getNumCols() - 1);
Optional<int64_t> lb = None;
for (unsigned r = 0; r < getNumInequalities(); r++) {
if (atIneq(r, pos) <= 0)
// Not a lower bound.
continue;
unsigned c;
for (c = 0; c < getNumCols() - 1; c++) {
if (c != pos && atIneq(r, c) != 0)
break;
}
// Not a constant lower bound.
if (c < getNumCols() - 1)
return None;
auto mayLb = mlir::ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, pos));
if (!lb.hasValue() || mayLb < lb.getValue())
lb = mayLb;
}
// TODO(andydavis,bondhugula): consider equalities (and an equality
// contradicting an inequality, i.e, an empty set).
return lb;
}
Optional<int64_t> FlatAffineConstraints::getConstantUpperBound(unsigned pos) {
assert(pos < getNumCols() - 1);
Optional<int64_t> ub = None;
for (unsigned r = 0; r < getNumInequalities(); r++) {
// Not a upper bound.
if (atIneq(r, pos) >= 0)
continue;
unsigned c;
for (c = 0; c < getNumCols() - 1; c++) {
if (c != pos && atIneq(r, c) != 0)
break;
}
// Not a constant upper bound.
if (c < getNumCols() - 1)
return None;
auto mayUb = mlir::floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, pos));
if (!ub.hasValue() || mayUb > ub.getValue())
ub = mayUb;
}
// TODO(andydavis,bondhugula): consider equalities (and an equality
// contradicting an inequality, i.e, an empty set).
return ub;
}
void FlatAffineConstraints::print(raw_ostream &os) const {
assert(inequalities.size() == getNumInequalities() * numReservedCols);
assert(equalities.size() == getNumEqualities() * numReservedCols);
@ -1127,7 +1200,7 @@ void FlatAffineConstraints::clearAndCopyFrom(
}
void FlatAffineConstraints::removeId(unsigned pos) {
assert(pos >= 0 && pos < getNumIds());
assert(pos < getNumIds());
if (pos < numDims)
numDims--;
@ -1363,7 +1436,7 @@ void FlatAffineConstraints::FourierMotzkinEliminate(
void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) {
// 'pos' can be at most getNumCols() - 2.
assert(pos >= 0 && pos <= getNumCols() - 2 && "invalid range");
assert(pos <= getNumCols() - 2 && "invalid position");
assert(pos + num < getNumCols() && "invalid range");
for (unsigned i = 0; i < num; i++) {
FourierMotzkinEliminate(pos);