Simplify memref-dependence-check's metadata structures / drop duplication and reuse existing ones.

- drop IterationDomainContext, redundant since FlatAffineConstraints has
  MLValue information associated with its dimensions.
- refactor to use existing support
- leads to a reduction in LOC
- as a result of these changes, non-constant loop bounds are naturally
  supported in dependence analysis.
- update test cases to include a couple with non-constant loop bounds
- rename addBoundsFromForStmt -> addForStmtDomain
- complete TODO for getLoopIVs (handle 'if' statements)

PiperOrigin-RevId: 226082008
Uday Bondhugula 2018-12-18 16:38:24 -08:00 committed by jpienaar
parent 1d72f2e47e
commit 14d2618f63
8 changed files with 152 additions and 166 deletions

View File

@@ -112,9 +112,13 @@ bool getFlattenedAffineExprs(
IntegerSet set, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs,
FlatAffineConstraints *cst = nullptr);
/// Adds constraints capturing the index set of the ML values in indices to
/// 'domain'.
bool addIndexSet(llvm::ArrayRef<const MLValue *> indices,
/// Builds a system of constraints with dimensional identifiers corresponding to
/// the loop IVs of the forStmts appearing in that order. Bounds of the loop are
used to add appropriate inequalities. Any symbols found in the bound
/// operands are added as symbols in the system. Returns false for the yet
/// unimplemented cases.
// TODO(bondhugula): handle non-unit strides.
bool getIndexSet(llvm::ArrayRef<ForStmt *> forStmts,
FlatAffineConstraints *domain);
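
For illustration, a hedged sketch of a typical call site (the surrounding pass code is assumed, not part of this commit):

// Build the iteration domain of the loops enclosing 'stmt'.
SmallVector<ForStmt *, 4> loops;
getLoopIVs(*stmt, &loops);        // loops surrounding 'stmt', outermost first
FlatAffineConstraints domain;
if (!getIndexSet(loops, &domain))
  return false;                   // a not-yet-implemented case (e.g., non-unit stride)
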
struct MemRefAccess {

View File

@@ -417,11 +417,12 @@ public:
/// right identifier is first looked up using forStmt's MLValue. Returns
/// false for the yet unimplemented/unsupported cases, and true if the
/// information is successfully added. Asserts if the MLValue corresponding to
/// the 'for' statement isn't found in the constaint system. Any new
/// the 'for' statement isn't found in the constraint system. Any new
/// identifiers that are found in the bound operands of the 'for' statement
/// are added as trailing identifiers (either dimensional or symbolic
/// depending on whether the operand is a valid MLFunction symbol).
bool addBoundsFromForStmt(const ForStmt &forStmt);
// TODO(bondhugula): add support for non-unit strides.
bool addForStmtDomain(const ForStmt &forStmt);
/// Adds an upper bound expression for the specified expression.
void addUpperBound(ArrayRef<int64_t> expr, ArrayRef<int64_t> ub);
@@ -513,6 +514,24 @@ public:
return {ids.data(), ids.size()};
}
/// Returns the MLValues associated with the identifiers. Asserts if
/// no MLValue was associated with an identifier.
inline void getIdValues(SmallVectorImpl<MLValue *> *values) const {
values->clear();
values->reserve(numIds);
for (unsigned i = 0; i < numIds; i++) {
assert(ids[i].hasValue() && "identifier's MLValue not set");
values->push_back(ids[i].getValue());
}
}
/// Returns the MLValue associated with the pos^th identifier. Asserts if
/// no MLValue was associated with the identifier at 'pos'.
inline MLValue *getIdValue(unsigned pos) const {
assert(ids[pos].hasValue() && "identifier's MLValue not set");
return ids[pos].getValue();
}
/// Clears this list of constraints and copies other into it.
void clearAndCopyFrom(const FlatAffineConstraints &other);
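
A hedged sketch of how the new accessors might be used by a caller (the caller code is assumed):

// Collect all identifier MLValues, then count how many are loop IVs.
SmallVector<MLValue *, 8> values;
cst.getIdValues(&values);         // asserts every identifier has an MLValue
unsigned numLoopIVs = 0;
for (auto *value : values)
  if (isa<ForStmt>(value))        // dimensions that correspond to loop IVs
    ++numLoopIVs;
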

View File

@@ -47,6 +47,7 @@ bool properlyDominates(const Statement &a, const Statement &b);
/// Populates 'loops' with IVs of the loops surrounding 'stmt' ordered from
/// the outermost 'for' statement to the innermost one.
// TODO(bondhugula): handle 'if' stmt's.
void getLoopIVs(const Statement &stmt, SmallVectorImpl<ForStmt *> *loops);
/// A region of a memref's data space; this is typically constructed by

View File

@@ -547,78 +547,40 @@ void mlir::forwardSubstituteReachableOps(AffineValueMap *valueMap) {
}
}
// Adds loop upper and lower bound inequalities to 'domain' for each ForStmt
// value in 'forStmts'. Requires that the first 'numDims' MLValues in 'forStmts'
// are ForStmts. Returns true if lower/upper bound inequalities were
// successfully added, returns false otherwise.
// TODO(andydavis) Get operands for loop bounds so we can add domain
// constraints for non-constant loop bounds.
// TODO(andydavis) Handle non-unit Step by adding local variable
// (iv - lb % step = 0 introducing a method in FlatAffineConstraints
// Builds a system of constraints with dimensional identifiers corresponding to
// the loop IVs of the forStmts appearing in that order. Any symbols found in
// the bound operands are added as symbols in the system. Returns false for the
// yet unimplemented cases.
// TODO(andydavis,bondhugula) Handle non-unit steps through local variables or
// stride information in FlatAffineConstraints (e.g., by using iv - lb %
// step = 0 and/or by introducing a method in FlatAffineConstraints
// setExprStride(ArrayRef<int64_t> expr, int64_t stride)
bool mlir::addIndexSet(ArrayRef<const MLValue *> indices,
bool mlir::getIndexSet(ArrayRef<ForStmt *> forStmts,
FlatAffineConstraints *domain) {
unsigned numIds = indices.size();
for (unsigned i = 0; i < numIds; ++i) {
const ForStmt *forStmt = dyn_cast<ForStmt>(indices[i]);
if (!forStmt || !forStmt->hasConstantBounds())
SmallVector<MLValue *, 4> indices(forStmts.begin(), forStmts.end());
// Reset 'domain', associating the MLValues in 'indices' with its dimensions.
domain->reset(forStmts.size(), /*numSymbols=*/0, /*numLocals=*/0, indices);
for (auto *forStmt : forStmts) {
// Add constraints from forStmt's bounds.
if (!domain->addForStmtDomain(*forStmt))
return false;
// Add inequalities from forStmt bounds.
domain->addBoundsFromForStmt(*forStmt);
}
return true;
}
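
One possible encoding for the non-unit-step TODO above, sketched here as an assumption (this commit does not implement it): introduce a local variable q with iv = lb + step * q, assuming a constant lower bound and FlatAffineConstraints::addEquality.

// Hypothetical: over columns [iv, q, const], with constant lower bound 'lb'.
int64_t step = forStmt->getStep();
int64_t lb = forStmt->getConstantLowerBound();
domain->addEquality({1, -step, -lb});  // iv - step*q - lb == 0
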
// IterationDomainContext encapsulates the state required to represent
// the iteration domain of an OperationStmt.
// TODO(andydavis) Move this into FlatAffineConstraints when we have shared
// code to manage the operand values and positions to use FlatAffineConstraints
// and AffineValueMap.
struct IterationDomainContext {
// Set of inequality constraint pairs, where each pair represents the
// upper/lower bounds of a ForStmt in the iteration domain.
FlatAffineConstraints domain;
// The number of dimension identifiers in 'values'.
unsigned numDims;
// The list of MLValues in this iteration domain, with MLValues in
// [0, numDims) representing dimension identifiers, and MLValues in
// [numDims, values.size()) representing symbol identifiers.
SmallVector<MLValue *, 4> values;
IterationDomainContext() : numDims(0) {}
unsigned getNumDims() const { return numDims; }
unsigned getNumSymbols() const { return values.size() - numDims; }
};
// Computes the iteration domain for 'opStmt' and populates 'ctx', which
// encapsulates the following state for each ForStmt in 'opStmt's iteration
// domain:
// *) adds inequality constraints representing the ForStmt upper/lower bounds.
// *) adds MLValues and symbols for the ForStmt and its operands to a list.
// TODO(andydavis) Add support for IfStmts in iteration domain.
// TODO(andydavis) Handle non-constant loop bounds by composing affine maps
// for each ForStmt loop bound and adding de-duped ids/symbols to iteration
// domain context.
static bool getIterationDomainContext(const Statement *stmt,
IterationDomainContext *ctx) {
// Computes the iteration domain for 'stmt' and populates 'indexSet', which
// encapsulates the constraints involving loops surrounding 'stmt' and
// potentially involving any MLFunction symbols. The dimensional identifiers in
// 'indexSet' correspond to the loops surrounding 'stmt' from outermost to
// innermost.
// TODO(andydavis) Add support to handle IfStmts surrounding 'stmt'.
static bool getStmtIndexSet(const Statement *stmt,
FlatAffineConstraints *indexSet) {
// TODO(andydavis) Extend this to gather enclosing IfStmts and consider
// factoring it out into a utility function.
SmallVector<ForStmt *, 4> loops;
getLoopIVs(*stmt, &loops);
// Iterate through 'loops' from outer-most loop to inner-most loop.
// Populate 'values'.
ctx->values.reserve(loops.size());
ctx->numDims += loops.size();
ctx->values.insert(ctx->values.end(), loops.begin(), loops.end());
// Resize flat affine constraint system based on num dims symbols found.
unsigned numDims = ctx->getNumDims();
unsigned numSymbols = ctx->getNumSymbols();
ctx->domain.reset(/*newNumReservedInequalities=*/2 * numDims,
/*newNumReservedEqualities=*/0,
/*newNumReservedCols=*/numDims + numSymbols + 1, numDims,
numSymbols, /*numLocals=*/0, /*idArgs=*/ctx->values);
return addIndexSet(ctx->values, &ctx->domain);
return getIndexSet(loops, indexSet);
}
// ValuePositionMap manages the mapping from MLValues which represent dimension
@@ -708,15 +670,13 @@ private:
// [src-dim-identifiers, dst-dim-identifiers, symbol-identifiers]
//
// This method populates 'valuePosMap' with mappings from operand MLValues in
// 'srcAccessMap'/'dstAccessMap' (as well as those in
// 'srcIterationDomainContext'/'dstIterationDomainContext') to the position of
// these values in the merged list.
// 'srcAccessMap'/'dstAccessMap' (as well as those in 'srcDomain'/'dstDomain')
// to the position of these values in the merged list.
static void buildDimAndSymbolPositionMaps(
const IterationDomainContext &srcIterationDomainContext,
const IterationDomainContext &dstIterationDomainContext,
const AffineValueMap &srcAccessMap, const AffineValueMap &dstAccessMap,
ValuePositionMap *valuePosMap) {
auto updateValuePosMap = [&](ArrayRef<const MLValue *> values, bool isSrc) {
const FlatAffineConstraints &srcDomain,
const FlatAffineConstraints &dstDomain, const AffineValueMap &srcAccessMap,
const AffineValueMap &dstAccessMap, ValuePositionMap *valuePosMap) {
auto updateValuePosMap = [&](ArrayRef<MLValue *> values, bool isSrc) {
for (unsigned i = 0, e = values.size(); i < e; ++i) {
auto *value = values[i];
if (!isa<ForStmt>(values[i]))
@@ -728,10 +688,14 @@ static void buildDimAndSymbolPositionMaps(
}
};
SmallVector<MLValue *, 4> srcValues, destValues;
srcDomain.getIdValues(&srcValues);
dstDomain.getIdValues(&destValues);
// Update value position map with identifiers from src iteration domain.
updateValuePosMap(srcIterationDomainContext.values, /*isSrc=*/true);
updateValuePosMap(srcValues, /*isSrc=*/true);
// Update value position map with identifiers from dst iteration domain.
updateValuePosMap(dstIterationDomainContext.values, /*isSrc=*/false);
updateValuePosMap(destValues, /*isSrc=*/false);
// Update value position map with identifiers from src access function.
updateValuePosMap(srcAccessMap.getOperands(), /*isSrc=*/true);
// Update value position map with identifiers from dst access function.
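
To make the merged column layout concrete, a hypothetical example (the loop and symbol names are assumed):

// For src loops (%i0, %i1), dst loop (%i2), and one symbol %N, the merged
// columns are ordered as:
//   [ %i0, %i1 | %i2 | %N | const ]
// valuePosMap records, for each MLValue, its column index in this layout.
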
@@ -745,22 +709,22 @@ static unsigned getPos(const DenseMap<const MLValue *, unsigned> &posMap,
return it->second;
}
// Adds iteration domain constraints from 'srcCtx' and 'dstCtx' into
// Adds iteration domain constraints from 'srcDomain' and 'dstDomain' into
// 'dependenceDomain'.
// Uses 'valuePosMap' to map from operand values in 'ctx.values' to position in
// 'dependenceDomain'.
static void addDomainConstraints(const IterationDomainContext &srcCtx,
const IterationDomainContext &dstCtx,
// Uses 'valuePosMap' to determine the position in 'dependenceDomain' to which a
// srcDomain/dstDomain MLValue maps.
static void addDomainConstraints(const FlatAffineConstraints &srcDomain,
const FlatAffineConstraints &dstDomain,
const ValuePositionMap &valuePosMap,
FlatAffineConstraints *dependenceDomain) {
unsigned srcNumIneq = srcCtx.domain.getNumInequalities();
unsigned srcNumDims = srcCtx.domain.getNumDimIds();
unsigned srcNumSymbols = srcCtx.domain.getNumSymbolIds();
unsigned srcNumIneq = srcDomain.getNumInequalities();
unsigned srcNumDims = srcDomain.getNumDimIds();
unsigned srcNumSymbols = srcDomain.getNumSymbolIds();
unsigned srcNumIds = srcNumDims + srcNumSymbols;
unsigned dstNumIneq = dstCtx.domain.getNumInequalities();
unsigned dstNumDims = dstCtx.domain.getNumDimIds();
unsigned dstNumSymbols = dstCtx.domain.getNumSymbolIds();
unsigned dstNumIneq = dstDomain.getNumInequalities();
unsigned dstNumDims = dstDomain.getNumDimIds();
unsigned dstNumSymbols = dstDomain.getNumSymbolIds();
unsigned dstNumIds = dstNumDims + dstNumSymbols;
SmallVector<int64_t, 4> ineq(dependenceDomain->getNumCols());
@@ -770,10 +734,10 @@ static void addDomainConstraints(const IterationDomainContext &srcCtx,
std::fill(ineq.begin(), ineq.end(), 0);
// Set coefficients for identifiers corresponding to src domain.
for (unsigned j = 0; j < srcNumIds; ++j)
ineq[valuePosMap.getSrcDimOrSymPos(srcCtx.values[j])] =
srcCtx.domain.atIneq(i, j);
ineq[valuePosMap.getSrcDimOrSymPos(srcDomain.getIdValue(j))] =
srcDomain.atIneq(i, j);
// Set constant term.
ineq[ineq.size() - 1] = srcCtx.domain.atIneq(i, srcNumIds);
ineq[ineq.size() - 1] = srcDomain.atIneq(i, srcNumIds);
// Add inequality constraint.
dependenceDomain->addInequality(ineq);
}
@@ -783,10 +747,10 @@ static void addDomainConstraints(const IterationDomainContext &srcCtx,
std::fill(ineq.begin(), ineq.end(), 0);
// Set coefficients for identifiers corresponding to dst domain.
for (unsigned j = 0; j < dstNumIds; ++j)
ineq[valuePosMap.getDstDimOrSymPos(dstCtx.values[j])] =
dstCtx.domain.atIneq(i, j);
ineq[valuePosMap.getDstDimOrSymPos(dstDomain.getIdValue(j))] =
dstDomain.atIneq(i, j);
// Set constant term.
ineq[ineq.size() - 1] = dstCtx.domain.atIneq(i, dstNumIds);
ineq[ineq.size() - 1] = dstDomain.atIneq(i, dstNumIds);
// Add inequality constraint.
dependenceDomain->addInequality(ineq);
}
@@ -908,19 +872,17 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
return true;
}
// Returns the number of outer loop common to 'src/dstIterationDomainContext'.
static unsigned
getNumCommonLoops(const IterationDomainContext &srcIterationDomainContext,
const IterationDomainContext &dstIterationDomainContext) {
// Returns the number of outer loops common to 'srcDomain' and 'dstDomain'.
static unsigned getNumCommonLoops(const FlatAffineConstraints &srcDomain,
const FlatAffineConstraints &dstDomain) {
// Find the number of common loops shared by src and dst accesses.
unsigned minNumLoops = std::min(srcIterationDomainContext.getNumDims(),
dstIterationDomainContext.getNumDims());
unsigned minNumLoops =
std::min(srcDomain.getNumDimIds(), dstDomain.getNumDimIds());
unsigned numCommonLoops = 0;
for (unsigned i = 0; i < minNumLoops; ++i) {
if (!isa<ForStmt>(srcIterationDomainContext.values[i]) ||
!isa<ForStmt>(dstIterationDomainContext.values[i]) ||
srcIterationDomainContext.values[i] !=
dstIterationDomainContext.values[i])
if (!isa<ForStmt>(srcDomain.getIdValue(i)) ||
!isa<ForStmt>(dstDomain.getIdValue(i)) ||
srcDomain.getIdValue(i) != dstDomain.getIdValue(i))
break;
++numCommonLoops;
}
@@ -931,15 +893,14 @@ getNumCommonLoops(const IterationDomainContext &srcIterationDomainContext,
// the operation statement in 'dstAccess'. Returns false otherwise.
// Note that 'numCommonLoops' is the number of contiguous surrounding outer
// loops.
static bool
srcHappensBeforeDst(const MemRefAccess &srcAccess,
const MemRefAccess &dstAccess,
const IterationDomainContext &srcIterationDomainContext,
unsigned numCommonLoops) {
static bool srcHappensBeforeDst(const MemRefAccess &srcAccess,
const MemRefAccess &dstAccess,
const FlatAffineConstraints &srcDomain,
unsigned numCommonLoops) {
if (numCommonLoops == 0) {
return mlir::properlyDominates(*srcAccess.opStmt, *dstAccess.opStmt);
}
auto *commonForValue = srcIterationDomainContext.values[numCommonLoops - 1];
auto *commonForValue = srcDomain.getIdValue(numCommonLoops - 1);
assert(isa<ForStmt>(commonForValue));
auto *commonForStmt = dyn_cast<ForStmt>(commonForValue);
// Check the dominance relationship between the respective ancestors of the
@@ -952,22 +913,20 @@ srcHappensBeforeDst(const MemRefAccess &srcAccess,
}
// Adds ordering constraints to 'dependenceDomain' based on number of loops
// common to 'src/dstIterationDomainContext' and requested 'loopDepth'.
// common to 'src/dstDomain' and requested 'loopDepth'.
// Note that 'loopDepth' cannot exceed the number of common loops plus one.
// EX: Given a loop nest of depth 2 with IVs 'i' and 'j':
// *) If 'loopDepth == 1' then one constraint is added: i' >= i + 1
// *) If 'loopDepth == 2' then two constraints are added: i == i' and j' >= j + 1
// *) If 'loopDepth == 3' then two constraints are added: i == i' and j == j'
static void
addOrderingConstraints(const IterationDomainContext &srcIterationDomainContext,
const IterationDomainContext &dstIterationDomainContext,
const ValuePositionMap &valuePosMap, unsigned loopDepth,
FlatAffineConstraints *dependenceDomain) {
static void addOrderingConstraints(const FlatAffineConstraints &srcDomain,
const FlatAffineConstraints &dstDomain,
unsigned loopDepth,
FlatAffineConstraints *dependenceDomain) {
unsigned numCols = dependenceDomain->getNumCols();
SmallVector<int64_t, 4> eq(numCols);
unsigned numSrcDims = valuePosMap.getNumSrcDims();
unsigned numCommonLoops =
getNumCommonLoops(srcIterationDomainContext, dstIterationDomainContext);
unsigned numSrcDims = srcDomain.getNumDimIds();
unsigned numCommonLoops = getNumCommonLoops(srcDomain, dstDomain);
unsigned numCommonLoopConstraints = std::min(numCommonLoops, loopDepth);
for (unsigned i = 0; i < numCommonLoopConstraints; ++i) {
std::fill(eq.begin(), eq.end(), 0);
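
Concretely, a hypothetical worked example of the rows this produces, assuming merged columns [i, j, i', j', const]:

// For loopDepth == 2 over src IVs (i, j) and dst IVs (i', j'):
//   equality  i' - i == 0      ->  { -1,  0,  1,  0,  0 }
//   ordering  j' - j - 1 >= 0  ->  {  0, -1,  0,  1, -1 }
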
@@ -1002,13 +961,12 @@ static bool hasSingleNonZeroAt(unsigned idPos, unsigned rowIdx, bool isEq,
// eliminating all other variables, and reading off distance vectors from
// equality constraints (if possible), and direction vectors from inequalities.
static void computeDirectionVector(
const IterationDomainContext &srcIterationDomainContext,
const IterationDomainContext &dstIterationDomainContext, unsigned loopDepth,
const FlatAffineConstraints &srcDomain,
const FlatAffineConstraints &dstDomain, unsigned loopDepth,
FlatAffineConstraints *dependenceDomain,
llvm::SmallVector<DependenceComponent, 2> *dependenceComponents) {
// Find the number of common loops shared by src and dst accesses.
unsigned numCommonLoops =
getNumCommonLoops(srcIterationDomainContext, dstIterationDomainContext);
unsigned numCommonLoops = getNumCommonLoops(srcDomain, dstDomain);
if (numCommonLoops == 0)
return;
// Compute direction vectors for requested loop depth.
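
As a hedged illustration of the distinction between distance and direction components:

// Hypothetical: after eliminating all identifiers except (i' - i), an
// equality pinning i' - i == 2 yields the constant distance 2 at that loop;
// if only bounds such as i' - i >= 1 survive, the component is reported as a
// range, e.g. [1, +inf] (the form seen in the updated tests below).
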
@@ -1119,7 +1077,7 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const {
// with each access.
// *) Build dimension and symbol position maps for each access, which map
// MLValues from access functions and iteration domains to their position
// in the merged constraint system build by this method.
// in the merged constraint system built by this method.
//
// This method builds a constraint system with the following column format:
//
@@ -1204,41 +1162,40 @@ bool mlir::checkMemrefAccessDependence(
AffineValueMap dstAccessMap;
dstAccess.getAccessMap(&dstAccessMap);
// Get iteration domain context for 'srcAccess'.
IterationDomainContext srcIterationDomainContext;
if (!getIterationDomainContext(srcAccess.opStmt, &srcIterationDomainContext))
// Get iteration domain for the 'srcAccess' statement.
FlatAffineConstraints srcDomain;
if (!getStmtIndexSet(srcAccess.opStmt, &srcDomain))
return false;
// Get iteration domain context for 'dstAccess'.
IterationDomainContext dstIterationDomainContext;
if (!getIterationDomainContext(dstAccess.opStmt, &dstIterationDomainContext))
// Get iteration domain for the 'dstAccess' statement.
FlatAffineConstraints dstDomain;
if (!getStmtIndexSet(dstAccess.opStmt, &dstDomain))
return false;
// Return if loopDepth > numCommonLoops and 'srcAccess' does not properly
// dominate 'dstAccess' (i.e. no execution path from src to dst access).
unsigned numCommonLoops =
getNumCommonLoops(srcIterationDomainContext, dstIterationDomainContext);
unsigned numCommonLoops = getNumCommonLoops(srcDomain, dstDomain);
assert(loopDepth <= numCommonLoops + 1);
if (loopDepth > numCommonLoops &&
!srcHappensBeforeDst(srcAccess, dstAccess, srcIterationDomainContext,
numCommonLoops)) {
!srcHappensBeforeDst(srcAccess, dstAccess, srcDomain, numCommonLoops)) {
return false;
}
// Build dim and symbol position maps for each access from access operand
// MLValue to position in merged constraint system.
ValuePositionMap valuePosMap;
buildDimAndSymbolPositionMaps(srcIterationDomainContext,
dstIterationDomainContext, srcAccessMap,
buildDimAndSymbolPositionMaps(srcDomain, dstDomain, srcAccessMap,
dstAccessMap, &valuePosMap);
assert(valuePosMap.getNumDims() ==
srcDomain.getNumDimIds() + dstDomain.getNumDimIds());
// Calculate number of equalities/inequalities and columns required to
// initialize FlatAffineConstraints for 'dependenceDomain'.
unsigned numIneq = srcIterationDomainContext.domain.getNumInequalities() +
dstIterationDomainContext.domain.getNumInequalities();
unsigned numIneq =
srcDomain.getNumInequalities() + dstDomain.getNumInequalities();
AffineMap srcMap = srcAccessMap.getAffineMap();
assert(srcMap.getNumResults() == dstAccessMap.getAffineMap().getNumResults());
unsigned numEq = srcMap.getNumResults();
unsigned numDims = valuePosMap.getNumDims();
unsigned numDims = srcDomain.getNumDimIds() + dstDomain.getNumDimIds();
unsigned numSymbols = valuePosMap.getNumSymbols();
unsigned numIds = numDims + numSymbols;
unsigned numCols = numIds + 1;
@@ -1254,11 +1211,11 @@ bool mlir::checkMemrefAccessDependence(
return true;
// Add 'src' happens before 'dst' ordering constraints.
addOrderingConstraints(srcIterationDomainContext, dstIterationDomainContext,
valuePosMap, loopDepth, dependenceConstraints);
addOrderingConstraints(srcDomain, dstDomain, loopDepth,
dependenceConstraints);
// Add src and dst domain constraints.
addDomainConstraints(srcIterationDomainContext, dstIterationDomainContext,
valuePosMap, dependenceConstraints);
addDomainConstraints(srcDomain, dstDomain, valuePosMap,
dependenceConstraints);
// Return false if the solution space is empty: no dependence.
if (dependenceConstraints->isEmpty()) {
@@ -1266,9 +1223,8 @@ bool mlir::checkMemrefAccessDependence(
}
// Compute dependence direction vector and return true.
if (dependenceComponents != nullptr) {
computeDirectionVector(srcIterationDomainContext, dstIterationDomainContext,
loopDepth, dependenceConstraints,
dependenceComponents);
computeDirectionVector(srcDomain, dstDomain, loopDepth,
dependenceConstraints, dependenceComponents);
}
return true;
}
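
A hedged sketch of a typical driver (the loop and 'maxLoopDepth' are assumed, not part of this commit):

// Probe all legal depths: 1 .. number of common surrounding loops + 1.
for (unsigned depth = 1; depth <= maxLoopDepth; ++depth) {
  FlatAffineConstraints dependenceConstraints;
  llvm::SmallVector<DependenceComponent, 2> dependenceComponents;
  if (checkMemrefAccessDependence(srcAccess, dstAccess, depth,
                                  &dependenceConstraints,
                                  &dependenceComponents)) {
    // A dependence exists at 'depth'; distance/direction info is in
    // 'dependenceComponents'.
  }
}
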

View File

@@ -1250,9 +1250,7 @@ void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) {
numSymbols = newSymbolCount;
}
// TODO(andydavis, bondhugula) AFFINE REFACTOR: merge with loop bounds
// code in dependence analysis.
bool FlatAffineConstraints::addBoundsFromForStmt(const ForStmt &forStmt) {
bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) {
unsigned pos;
// Pre-condition for this method.
if (!findId(*cast<MLValue>(&forStmt), &pos)) {
@@ -1260,14 +1258,16 @@ bool FlatAffineConstraints::addBoundsFromForStmt(const ForStmt &forStmt) {
return false;
}
if (forStmt.getStep() != 1)
LLVM_DEBUG(llvm::dbgs()
<< "Domain conservative: non-unit stride not handled\n");
// Adds a lower or upper bound when the bounds aren't constant.
auto addLowerOrUpperBound = [&](bool lower) -> bool {
auto operands = lower ? forStmt.getLowerBoundOperands()
: forStmt.getUpperBoundOperands();
for (const auto &operand : operands) {
unsigned loc;
// TODO(andydavis, bondhugula) AFFINE REFACTOR: merge with loop bounds
// code in dependence analysis.
if (!findId(*operand, &loc)) {
if (operand->isValidSymbol()) {
addSymbolId(getNumSymbolIds(), const_cast<MLValue *>(operand));
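
For illustration, a hedged sketch of driving this method directly, mirroring what getIndexSet does (the 'forStmt' in scope is assumed):

// Seed a system whose single dimension is the loop IV, then add its bounds.
SmallVector<MLValue *, 1> ids = {const_cast<MLValue *>(cast<MLValue>(&forStmt))};
FlatAffineConstraints cst;
cst.reset(/*numDims=*/1, /*numSymbols=*/0, /*numLocals=*/0, ids);
if (!cst.addForStmtDomain(forStmt))
  return false;  // a bound form that isn't supported yet
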

View File

@@ -70,12 +70,16 @@ bool mlir::dominates(const Statement &a, const Statement &b) {
/// Populates 'loops' with IVs of the loops surrounding 'stmt' ordered from
/// the outermost 'for' statement to the innermost one.
// TODO(mlir-team): skip over 'if' statements.
void mlir::getLoopIVs(const Statement &stmt,
SmallVectorImpl<ForStmt *> *loops) {
auto *currStmt = stmt.getParentStmt();
while (currStmt != nullptr && isa<ForStmt>(currStmt)) {
loops->push_back(dyn_cast<ForStmt>(currStmt));
ForStmt *currForStmt;
// Traverse up the hierarchy, collecting all 'for' statements while skipping over
// 'if' statements.
while (currStmt && ((currForStmt = dyn_cast<ForStmt>(currStmt)) ||
isa<IfStmt>(currStmt))) {
if (currForStmt)
loops->push_back(currForStmt);
currStmt = currStmt->getParentStmt();
}
std::reverse(loops->begin(), loops->end());
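
A hedged illustration of the new traversal on a hypothetical nest:

// Given:   for %i0 { if #set0(%i0) { for %i1 { %op } } }
// getLoopIVs(*op, &loops) now returns [%i0's ForStmt, %i1's ForStmt],
// skipping the intervening 'if' statement rather than stopping at it.
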
@@ -190,7 +194,9 @@ bool mlir::getMemRefRegion(OperationStmt *opStmt, unsigned loopDepth,
if (auto *loop = dyn_cast<ForStmt>(accessValueMap.getOperand(i))) {
// Note that regionCst can now have more dimensions than accessMap if the
// bounds expressions involve outer loops or other symbols.
if (!regionCst->addBoundsFromForStmt(*loop))
// TODO(bondhugula): rewrite this to use getStmtIndexSet; this way
// conditionals will be handled when the latter supports it.
if (!regionCst->addForStmtDomain(*loop))
return false;
} else {
// Has to be a valid symbol.

View File

@@ -191,8 +191,8 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
SmallVector<MLValue *, 6> origLoopIVs(band.begin(), band.end());
SmallVector<Optional<MLValue *>, 6> ids(band.begin(), band.end());
FlatAffineConstraints cst(width, /*numSymbols=*/0, /*numLocals=*/0, ids);
addIndexSet(origLoopIVs, &cst);
FlatAffineConstraints cst;
getIndexSet(band, &cst);
if (!cst.isHyperRectangular(0, width)) {
rootForStmt->emitError("tiled code generation unimplemented for the"

View File

@@ -197,21 +197,21 @@ mlfunc @store_range_load_after_range() {
}
// -----
// CHECK-LABEL: mlfunc @store_load_func_symbol(%arg0 : index) {
mlfunc @store_load_func_symbol(%arg0 : index) {
// CHECK-LABEL: mlfunc @store_load_func_symbol(%arg0 : index, %arg1 : index) {
mlfunc @store_load_func_symbol(%arg0 : index, %arg1 : index) {
%m = alloc() : memref<100xf32>
%c7 = constant 7.0 : f32
%c10 = constant 10 : index
for %i0 = 0 to 10 {
for %i0 = 0 to %arg1 {
%a0 = affine_apply (d0) -> (d0) (%arg0)
store %c7, %m[%a0] : memref<100xf32>
// expected-note@-1 {{dependence from 0 to 0 at depth 1 = [1, 9]}}
// expected-note@-1 {{dependence from 0 to 0 at depth 1 = [1, +inf]}}
// expected-note@-2 {{dependence from 0 to 0 at depth 2 = false}}
// expected-note@-3 {{dependence from 0 to 1 at depth 1 = [1, 9]}}
// expected-note@-3 {{dependence from 0 to 1 at depth 1 = [1, +inf]}}
// expected-note@-4 {{dependence from 0 to 1 at depth 2 = true}}
%a1 = affine_apply (d0) -> (d0) (%arg0)
%v0 = load %m[%a1] : memref<100xf32>
// expected-note@-1 {{dependence from 1 to 0 at depth 1 = [1, 9]}}
// expected-note@-1 {{dependence from 1 to 0 at depth 1 = [1, +inf]}}
// expected-note@-2 {{dependence from 1 to 0 at depth 2 = false}}
// expected-note@-3 {{dependence from 1 to 1 at depth 1 = false}}
// expected-note@-4 {{dependence from 1 to 1 at depth 2 = false}}
@@ -511,12 +511,12 @@ mlfunc @dependence_cycle() {
}
// -----
// CHECK-LABEL: mlfunc @negative_and_positive_direction_vectors() {
mlfunc @negative_and_positive_direction_vectors() {
// CHECK-LABEL: mlfunc @negative_and_positive_direction_vectors(%arg0 : index, %arg1 : index) {
mlfunc @negative_and_positive_direction_vectors(%arg0 : index, %arg1 : index) {
%m = alloc() : memref<10x10xf32>
%c7 = constant 7.0 : f32
for %i0 = 0 to 10 {
for %i1 = 0 to 10 {
for %i0 = 0 to %arg0 {
for %i1 = 0 to %arg1 {
%a0 = affine_apply (d0, d1) -> (d0 - 1, d1 + 1) (%i0, %i1)
%v0 = load %m[%a0#0, %a0#1] : memref<10x10xf32>
// expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}}