forked from OSchip/llvm-project
Clean up memref dep check utilities; update FlatAffineConstraints API, add
simple utility methods. - clean up some of the analysis utilities used by memref dep checking - add additional asserts / comments at places in analysis utilities - add additional simple methods to the FlatAffineConstraints API. PiperOrigin-RevId: 220124523
This commit is contained in:
parent
9a62178372
commit
4269a01863
|
@ -37,6 +37,7 @@ class MLIRContext;
|
|||
class FlatAffineConstraints;
|
||||
class MLValue;
|
||||
class OperationStmt;
|
||||
class Statement;
|
||||
|
||||
/// Simplify an affine expression through flattening and some amount of
|
||||
/// simple analysis. This has complexity linear in the number of nodes in
|
||||
|
@ -66,6 +67,11 @@ bool getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
|
|||
llvm::SmallVectorImpl<int64_t> *flattenedExpr,
|
||||
FlatAffineConstraints *cst = nullptr);
|
||||
|
||||
/// Adds constraints capturing the index set of the ML values in indices to
|
||||
/// 'domain'.
|
||||
bool addIndexSet(llvm::ArrayRef<const MLValue *> indices,
|
||||
FlatAffineConstraints *domain);
|
||||
|
||||
/// Checks whether two accesses to the same memref access the same element.
|
||||
/// Each access is specified using the MemRefAccess structure, which contains
|
||||
/// the operation statement, indices and memref associated with the access.
|
||||
|
|
|
@ -456,6 +456,23 @@ public:
|
|||
/// Clears this list of constraints and copies other into it.
|
||||
void clearAndCopyFrom(const FlatAffineConstraints &other);
|
||||
|
||||
/// Returns the constant lower bound of the specified identifier (through a
|
||||
/// scan through the constraints); returns None if the bound isn't trivially a
|
||||
/// constant.
|
||||
Optional<int64_t> getConstantLowerBound(unsigned pos);
|
||||
|
||||
/// Returns the constant upper bound of the specified identifier (through a
|
||||
/// scan through the constraints); returns None if the bound isn't trivially a
|
||||
/// constant.
|
||||
Optional<int64_t> getConstantUpperBound(unsigned pos);
|
||||
|
||||
// Returns the lower and upper bounds of the specified dimensions as
|
||||
// AffineMap's. Returns false for the unimplemented cases for the moment.
|
||||
bool getDimensionBounds(unsigned pos, unsigned num,
|
||||
SmallVectorImpl<AffineMap> *lbs,
|
||||
SmallVectorImpl<AffineMap> *ubs,
|
||||
MLIRContext *context);
|
||||
|
||||
// More expensive ones.
|
||||
void removeDuplicates();
|
||||
|
||||
|
|
|
@ -216,11 +216,13 @@ public:
|
|||
void visitDimExpr(AffineDimExpr expr) {
|
||||
operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
|
||||
auto &eq = operandExprStack.back();
|
||||
assert(expr.getPosition() < numDims && "Inconsistent number of dims");
|
||||
eq[getDimStartIndex() + expr.getPosition()] = 1;
|
||||
}
|
||||
void visitSymbolExpr(AffineSymbolExpr expr) {
|
||||
operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
|
||||
auto &eq = operandExprStack.back();
|
||||
assert(expr.getPosition() < numSymbols && "inconsistent number of symbols");
|
||||
eq[getSymbolStartIndex() + expr.getPosition()] = 1;
|
||||
}
|
||||
void visitConstantExpr(AffineConstantExpr expr) {
|
||||
|
@ -283,6 +285,8 @@ private:
|
|||
bound[bound.size() - 1] = -(rhsConst - 1);
|
||||
cst.addLowerBound(lhs, bound);
|
||||
}
|
||||
// Set the expression on stack to the local var introduced to capture the
|
||||
// result of the division (floor or ceil).
|
||||
std::fill(lhs.begin(), lhs.end(), 0);
|
||||
lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
|
||||
}
|
||||
|
@ -420,29 +424,17 @@ void mlir::forwardSubstituteReachableOps(AffineValueMap *valueMap) {
|
|||
// TODO(andydavis) Handle non-unit Step by adding local variable
|
||||
// (iv - lb % step = 0 introducing a method in FlatAffineConstraints
|
||||
// setExprStride(ArrayRef<int64_t> expr, int64_t stride)
|
||||
static bool addForStmtBounds(unsigned numDims,
|
||||
ArrayRef<const MLValue *> forStmts,
|
||||
FlatAffineConstraints *domain) {
|
||||
assert(forStmts.size() >= numDims);
|
||||
unsigned numIds = forStmts.size();
|
||||
// Add InEqualties for loop bounds.
|
||||
SmallVector<int64_t, 4> ineq;
|
||||
ineq.resize(numIds + 1);
|
||||
for (unsigned i = 0; i < numDims; ++i) {
|
||||
const ForStmt *forStmt = dyn_cast<ForStmt>(forStmts[i]);
|
||||
bool mlir::addIndexSet(ArrayRef<const MLValue *> indices,
|
||||
FlatAffineConstraints *domain) {
|
||||
unsigned numIds = indices.size();
|
||||
for (unsigned i = 0; i < numIds; ++i) {
|
||||
const ForStmt *forStmt = dyn_cast<ForStmt>(indices[i]);
|
||||
if (!forStmt || !forStmt->hasConstantBounds())
|
||||
return false;
|
||||
// Zero fill
|
||||
std::fill(ineq.begin(), ineq.end(), 0);
|
||||
// TODO(andydavis, bondhugula) Add methods for addUpper/LowerBound.
|
||||
// Add inequality for lower bound.
|
||||
ineq[i] = 1;
|
||||
ineq[numIds] = -forStmt->getConstantLowerBound();
|
||||
domain->addInequality(ineq);
|
||||
domain->addConstantLowerBound(i, forStmt->getConstantLowerBound());
|
||||
// Add inequality for upper bound.
|
||||
ineq[i] = -1;
|
||||
ineq[numIds] = forStmt->getConstantUpperBound();
|
||||
domain->addInequality(ineq);
|
||||
domain->addConstantUpperBound(i, forStmt->getConstantUpperBound());
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@ -476,13 +468,13 @@ struct IterationDomainContext {
|
|||
// TODO(andydavis) Capture the context of the symbols. For example, check
|
||||
// if a symbol is the result of a constant operation, and set the symbol to
|
||||
// that value in FlatAffineConstraints (using setIdToConstant).
|
||||
static bool getIterationDomainContext(const OperationStmt *opStmt,
|
||||
IterationDomainContext *ctx) {
|
||||
bool getIterationDomainContext(const Statement *stmt,
|
||||
IterationDomainContext *ctx) {
|
||||
// Walk up tree storing parent statements in 'loops'.
|
||||
// TODO(andydavis) Extend this to gather enclosing IfStmts and consider
|
||||
// factoring it out into a utility function.
|
||||
SmallVector<const ForStmt *, 4> loops;
|
||||
const auto *currStmt = opStmt->getParentStmt();
|
||||
const auto *currStmt = stmt->getParentStmt();
|
||||
while (currStmt != nullptr) {
|
||||
if (isa<IfStmt>(currStmt))
|
||||
return false;
|
||||
|
@ -510,7 +502,7 @@ static bool getIterationDomainContext(const OperationStmt *opStmt,
|
|||
/*newNumReservedEqualities=*/0,
|
||||
/*newNumReservedCols=*/numDims + numSymbols + 1, numDims,
|
||||
numSymbols);
|
||||
return addForStmtBounds(numDims, ctx->values, &ctx->domain);
|
||||
return addIndexSet(ctx->values, &ctx->domain);
|
||||
}
|
||||
|
||||
// Builds a map from MLValue to identifier position in a new merged identifier
|
||||
|
|
|
@ -578,7 +578,6 @@ void FlatAffineConstraints::addSymbolId(unsigned pos) {
|
|||
/// Adds a dimensional identifier. The added column is initialized to
|
||||
/// zero.
|
||||
void FlatAffineConstraints::addId(IdKind kind, unsigned pos) {
|
||||
assert(pos >= 0);
|
||||
if (kind == IdKind::Dimension) {
|
||||
assert(pos <= getNumDimIds());
|
||||
} else if (kind == IdKind::Symbol) {
|
||||
|
@ -645,7 +644,7 @@ void FlatAffineConstraints::addId(IdKind kind, unsigned pos) {
|
|||
void FlatAffineConstraints::composeMap(AffineValueMap *vMap, unsigned pos) {
|
||||
assert(vMap->getNumOperands() == getNumIds() && "inconsistent map");
|
||||
assert(vMap->getNumDims() == getNumDimIds() && "inconsistent map");
|
||||
assert(pos >= 0 && pos <= getNumIds() && "invalid position");
|
||||
assert(pos <= getNumIds() && "invalid position");
|
||||
|
||||
AffineMap map = vMap->getAffineMap();
|
||||
|
||||
|
@ -1003,7 +1002,7 @@ void FlatAffineConstraints::addInequality(ArrayRef<int64_t> inEq) {
|
|||
}
|
||||
|
||||
void FlatAffineConstraints::addConstantLowerBound(unsigned pos, int64_t lb) {
|
||||
assert(pos >= 0 && pos < getNumCols());
|
||||
assert(pos < getNumCols());
|
||||
unsigned offset = inequalities.size();
|
||||
inequalities.resize(inequalities.size() + numReservedCols);
|
||||
std::fill(inequalities.begin() + offset,
|
||||
|
@ -1013,7 +1012,7 @@ void FlatAffineConstraints::addConstantLowerBound(unsigned pos, int64_t lb) {
|
|||
}
|
||||
|
||||
void FlatAffineConstraints::addConstantUpperBound(unsigned pos, int64_t ub) {
|
||||
assert(pos >= 0 && pos < getNumCols());
|
||||
assert(pos < getNumCols());
|
||||
unsigned offset = inequalities.size();
|
||||
inequalities.resize(inequalities.size() + numReservedCols);
|
||||
std::fill(inequalities.begin() + offset,
|
||||
|
@ -1095,6 +1094,80 @@ void FlatAffineConstraints::removeEquality(unsigned pos) {
|
|||
equalities.resize(equalities.size() - numReservedCols);
|
||||
}
|
||||
|
||||
bool FlatAffineConstraints::getDimensionBounds(unsigned pos, unsigned num,
|
||||
SmallVectorImpl<AffineMap> *lbs,
|
||||
SmallVectorImpl<AffineMap> *ubs,
|
||||
MLIRContext *context) {
|
||||
assert(pos + num < getNumCols());
|
||||
|
||||
// Only constant dim bounds for now.
|
||||
projectOut(0, pos);
|
||||
projectOut(pos + num, getNumIds() - num);
|
||||
|
||||
lbs->resize(num, AffineMap::Null());
|
||||
ubs->resize(num, AffineMap::Null());
|
||||
|
||||
for (int i = static_cast<int>(num) - 1; i >= 0; i--) {
|
||||
auto lb = getConstantLowerBound(i);
|
||||
auto ub = getConstantUpperBound(i);
|
||||
// TODO(mlir-team): handle arbitrary bounds.
|
||||
if (!lb.hasValue() || !ub.hasValue())
|
||||
return false;
|
||||
(*lbs)[i] = AffineMap::getConstantMap(lb.getValue(), context);
|
||||
(*ubs)[i] = AffineMap::getConstantMap(ub.getValue(), context);
|
||||
projectOut(i, 1);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
Optional<int64_t> FlatAffineConstraints::getConstantLowerBound(unsigned pos) {
|
||||
assert(pos < getNumCols() - 1);
|
||||
Optional<int64_t> lb = None;
|
||||
for (unsigned r = 0; r < getNumInequalities(); r++) {
|
||||
if (atIneq(r, pos) <= 0)
|
||||
// Not a lower bound.
|
||||
continue;
|
||||
unsigned c;
|
||||
for (c = 0; c < getNumCols() - 1; c++) {
|
||||
if (c != pos && atIneq(r, c) != 0)
|
||||
break;
|
||||
}
|
||||
// Not a constant lower bound.
|
||||
if (c < getNumCols() - 1)
|
||||
return None;
|
||||
auto mayLb = mlir::ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, pos));
|
||||
if (!lb.hasValue() || mayLb < lb.getValue())
|
||||
lb = mayLb;
|
||||
}
|
||||
// TODO(andydavis,bondhugula): consider equalities (and an equality
|
||||
// contradicting an inequality, i.e, an empty set).
|
||||
return lb;
|
||||
}
|
||||
|
||||
Optional<int64_t> FlatAffineConstraints::getConstantUpperBound(unsigned pos) {
|
||||
assert(pos < getNumCols() - 1);
|
||||
Optional<int64_t> ub = None;
|
||||
for (unsigned r = 0; r < getNumInequalities(); r++) {
|
||||
// Not a upper bound.
|
||||
if (atIneq(r, pos) >= 0)
|
||||
continue;
|
||||
unsigned c;
|
||||
for (c = 0; c < getNumCols() - 1; c++) {
|
||||
if (c != pos && atIneq(r, c) != 0)
|
||||
break;
|
||||
}
|
||||
// Not a constant upper bound.
|
||||
if (c < getNumCols() - 1)
|
||||
return None;
|
||||
auto mayUb = mlir::floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, pos));
|
||||
if (!ub.hasValue() || mayUb > ub.getValue())
|
||||
ub = mayUb;
|
||||
}
|
||||
// TODO(andydavis,bondhugula): consider equalities (and an equality
|
||||
// contradicting an inequality, i.e, an empty set).
|
||||
return ub;
|
||||
}
|
||||
|
||||
void FlatAffineConstraints::print(raw_ostream &os) const {
|
||||
assert(inequalities.size() == getNumInequalities() * numReservedCols);
|
||||
assert(equalities.size() == getNumEqualities() * numReservedCols);
|
||||
|
@ -1127,7 +1200,7 @@ void FlatAffineConstraints::clearAndCopyFrom(
|
|||
}
|
||||
|
||||
void FlatAffineConstraints::removeId(unsigned pos) {
|
||||
assert(pos >= 0 && pos < getNumIds());
|
||||
assert(pos < getNumIds());
|
||||
|
||||
if (pos < numDims)
|
||||
numDims--;
|
||||
|
@ -1363,7 +1436,7 @@ void FlatAffineConstraints::FourierMotzkinEliminate(
|
|||
|
||||
void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) {
|
||||
// 'pos' can be at most getNumCols() - 2.
|
||||
assert(pos >= 0 && pos <= getNumCols() - 2 && "invalid range");
|
||||
assert(pos <= getNumCols() - 2 && "invalid position");
|
||||
assert(pos + num < getNumCols() && "invalid range");
|
||||
for (unsigned i = 0; i < num; i++) {
|
||||
FourierMotzkinEliminate(pos);
|
||||
|
|
Loading…
Reference in New Issue