Refactor/update memref-dep-check's addMemRefAccessConstraints and

addDomainConstraints; add support for mod/div for dependence testing.

- add support for mod/div expressions in dependence analysis
- refactor addMemRefAccessConstraints to use getFlattenedAffineExprs (instead
  of getFlattenedAffineExpr); update addDomainConstraints.
- rename AffineExprFlattener::cst -> localVarCst

PiperOrigin-RevId: 225933306
This commit is contained in:
Uday Bondhugula 2018-12-17 20:16:37 -08:00 committed by jpienaar
parent 4dbd94b543
commit 20531932f4
3 changed files with 111 additions and 75 deletions

View File

@ -120,7 +120,7 @@ namespace {
// is more efficient than creating a new flattener for each expression since // is more efficient than creating a new flattener for each expression since
// common identical div and mod expressions appearing across different // common identical div and mod expressions appearing across different
// expressions are mapped to the local identifier (same column position in // expressions are mapped to the local identifier (same column position in
// 'cst'). // 'localVarCst').
struct AffineExprFlattener : public AffineExprVisitor<AffineExprFlattener> { struct AffineExprFlattener : public AffineExprVisitor<AffineExprFlattener> {
public: public:
// Flattened expression layout: [dims, symbols, locals, constant] // Flattened expression layout: [dims, symbols, locals, constant]
@ -129,9 +129,10 @@ public:
// will be, and linearize this to std::vector<int64_t> to prevent // will be, and linearize this to std::vector<int64_t> to prevent
// SmallVector moves on re-allocation. // SmallVector moves on re-allocation.
std::vector<SmallVector<int64_t, 32>> operandExprStack; std::vector<SmallVector<int64_t, 32>> operandExprStack;
// Constraints connecting newly introduced local variables to existing // Constraints connecting newly introduced local variables (for mod's and
// (dimensional and symbolic) ones. // div's) to existing (dimensional and symbolic) ones. These are always
FlatAffineConstraints cst; // inequalities.
FlatAffineConstraints localVarCst;
unsigned numDims; unsigned numDims;
unsigned numSymbols; unsigned numSymbols;
@ -153,7 +154,7 @@ public:
: numDims(numDims), numSymbols(numSymbols), numLocals(0), : numDims(numDims), numSymbols(numSymbols), numLocals(0),
context(context) { context(context) {
operandExprStack.reserve(8); operandExprStack.reserve(8);
cst.reset(numDims, numSymbols, numLocals); localVarCst.reset(numDims, numSymbols, numLocals);
} }
void visitMulExpr(AffineBinaryOpExpr expr) { void visitMulExpr(AffineBinaryOpExpr expr) {
@ -214,9 +215,9 @@ public:
if ((loc = findLocalId(floorDiv)) == -1) { if ((loc = findLocalId(floorDiv)) == -1) {
addLocalId(floorDiv); addLocalId(floorDiv);
lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst; lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
// Update cst: 0 <= expr1 - c * expr2 <= c - 1. // Update localVarCst: 0 <= expr1 - c * expr2 <= c - 1.
cst.addConstantLowerBound(lhs, 0); localVarCst.addConstantLowerBound(lhs, 0);
cst.addConstantUpperBound(lhs, rhsConst - 1); localVarCst.addConstantUpperBound(lhs, rhsConst - 1);
} else { } else {
// Reuse the existing local id. // Reuse the existing local id.
lhs[getLocalVarStartIndex() + loc] = -rhsConst; lhs[getLocalVarStartIndex() + loc] = -rhsConst;
@ -305,14 +306,14 @@ private:
bound[getLocalVarStartIndex() + numLocals - 1] = rhsConst; bound[getLocalVarStartIndex() + numLocals - 1] = rhsConst;
if (!isCeil) { if (!isCeil) {
// q = lhs floordiv c <=> c*q <= lhs <= c*q + c - 1. // q = lhs floordiv c <=> c*q <= lhs <= c*q + c - 1.
cst.addLowerBound(lhs, bound); localVarCst.addLowerBound(lhs, bound);
bound[bound.size() - 1] = rhsConst - 1; bound[bound.size() - 1] = rhsConst - 1;
cst.addUpperBound(lhs, bound); localVarCst.addUpperBound(lhs, bound);
} else { } else {
// q = lhs ceildiv c <=> c*q - (c - 1) <= lhs <= c*q. // q = lhs ceildiv c <=> c*q - (c - 1) <= lhs <= c*q.
cst.addUpperBound(lhs, bound); localVarCst.addUpperBound(lhs, bound);
bound[bound.size() - 1] = -(rhsConst - 1); bound[bound.size() - 1] = -(rhsConst - 1);
cst.addLowerBound(lhs, bound); localVarCst.addLowerBound(lhs, bound);
} }
} }
// Set the expression on stack to the local var introduced to capture the // Set the expression on stack to the local var introduced to capture the
@ -333,7 +334,7 @@ private:
} }
localExprs.push_back(localExpr); localExprs.push_back(localExpr);
numLocals++; numLocals++;
cst.addLocalId(cst.getNumLocalIds()); localVarCst.addLocalId(localVarCst.getNumLocalIds());
} }
int findLocalId(AffineExpr localExpr) { int findLocalId(AffineExpr localExpr) {
@ -409,9 +410,9 @@ AffineExpr mlir::composeWithUnboundedMap(AffineExpr e, AffineMap g) {
static bool getFlattenedAffineExprs( static bool getFlattenedAffineExprs(
ArrayRef<AffineExpr> exprs, unsigned numDims, unsigned numSymbols, ArrayRef<AffineExpr> exprs, unsigned numDims, unsigned numSymbols,
std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs,
FlatAffineConstraints *cst) { FlatAffineConstraints *localVarCst) {
if (exprs.empty()) { if (exprs.empty()) {
cst->reset(numDims, numSymbols); localVarCst->reset(numDims, numSymbols);
return true; return true;
} }
@ -435,8 +436,8 @@ static bool getFlattenedAffineExprs(
flattenedExprs->push_back(flattenedExpr); flattenedExprs->push_back(flattenedExpr);
flattener.operandExprStack.pop_back(); flattener.operandExprStack.pop_back();
} }
if (cst) if (localVarCst)
cst->clearAndCopyFrom(flattener.cst); localVarCst->clearAndCopyFrom(flattener.localVarCst);
return true; return true;
} }
@ -447,10 +448,10 @@ static bool getFlattenedAffineExprs(
bool mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims, bool mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
unsigned numSymbols, unsigned numSymbols,
llvm::SmallVectorImpl<int64_t> *flattenedExpr, llvm::SmallVectorImpl<int64_t> *flattenedExpr,
FlatAffineConstraints *cst) { FlatAffineConstraints *localVarCst) {
std::vector<SmallVector<int64_t, 8>> flattenedExprs; std::vector<SmallVector<int64_t, 8>> flattenedExprs;
bool ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols, bool ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols,
&flattenedExprs, cst); &flattenedExprs, localVarCst);
*flattenedExpr = flattenedExprs[0]; *flattenedExpr = flattenedExprs[0];
return ret; return ret;
} }
@ -460,24 +461,26 @@ bool mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
/// handled yet). /// handled yet).
bool mlir::getFlattenedAffineExprs( bool mlir::getFlattenedAffineExprs(
AffineMap map, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs, AffineMap map, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs,
FlatAffineConstraints *cst) { FlatAffineConstraints *localVarCst) {
if (map.getNumResults() == 0) { if (map.getNumResults() == 0) {
cst->reset(map.getNumDims(), map.getNumSymbols()); localVarCst->reset(map.getNumDims(), map.getNumSymbols());
return true; return true;
} }
return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(), return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(),
map.getNumSymbols(), flattenedExprs, cst); map.getNumSymbols(), flattenedExprs,
localVarCst);
} }
bool mlir::getFlattenedAffineExprs( bool mlir::getFlattenedAffineExprs(
IntegerSet set, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs, IntegerSet set, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs,
FlatAffineConstraints *cst) { FlatAffineConstraints *localVarCst) {
if (set.getNumConstraints() == 0) { if (set.getNumConstraints() == 0) {
cst->reset(set.getNumDims(), set.getNumSymbols()); localVarCst->reset(set.getNumDims(), set.getNumSymbols());
return true; return true;
} }
return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(), return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(),
set.getNumSymbols(), flattenedExprs, cst); set.getNumSymbols(), flattenedExprs,
localVarCst);
} }
/// Returns the sequence of AffineApplyOp OperationStmts operation in /// Returns the sequence of AffineApplyOp OperationStmts operation in
@ -760,12 +763,7 @@ static void addDomainConstraints(const IterationDomainContext &srcCtx,
unsigned dstNumSymbols = dstCtx.domain.getNumSymbolIds(); unsigned dstNumSymbols = dstCtx.domain.getNumSymbolIds();
unsigned dstNumIds = dstNumDims + dstNumSymbols; unsigned dstNumIds = dstNumDims + dstNumSymbols;
unsigned outputNumDims = dependenceDomain->getNumDimIds(); SmallVector<int64_t, 4> ineq(dependenceDomain->getNumCols());
unsigned outputNumSymbols = dependenceDomain->getNumSymbolIds();
unsigned outputNumIds = outputNumDims + outputNumSymbols;
SmallVector<int64_t, 4> ineq;
ineq.resize(outputNumIds + 1);
// Add inequalities from src domain. // Add inequalities from src domain.
for (unsigned i = 0; i < srcNumIneq; ++i) { for (unsigned i = 0; i < srcNumIneq; ++i) {
// Zero fill. // Zero fill.
@ -775,7 +773,7 @@ static void addDomainConstraints(const IterationDomainContext &srcCtx,
ineq[valuePosMap.getSrcDimOrSymPos(srcCtx.values[j])] = ineq[valuePosMap.getSrcDimOrSymPos(srcCtx.values[j])] =
srcCtx.domain.atIneq(i, j); srcCtx.domain.atIneq(i, j);
// Set constant term. // Set constant term.
ineq[outputNumIds] = srcCtx.domain.atIneq(i, srcNumIds); ineq[ineq.size() - 1] = srcCtx.domain.atIneq(i, srcNumIds);
// Add inequality constraint. // Add inequality constraint.
dependenceDomain->addInequality(ineq); dependenceDomain->addInequality(ineq);
} }
@ -788,7 +786,7 @@ static void addDomainConstraints(const IterationDomainContext &srcCtx,
ineq[valuePosMap.getDstDimOrSymPos(dstCtx.values[j])] = ineq[valuePosMap.getDstDimOrSymPos(dstCtx.values[j])] =
dstCtx.domain.atIneq(i, j); dstCtx.domain.atIneq(i, j);
// Set constant term. // Set constant term.
ineq[outputNumIds] = dstCtx.domain.atIneq(i, dstNumIds); ineq[ineq.size() - 1] = dstCtx.domain.atIneq(i, dstNumIds);
// Add inequality constraint. // Add inequality constraint.
dependenceDomain->addInequality(ineq); dependenceDomain->addInequality(ineq);
} }
@ -815,8 +813,8 @@ static void addDomainConstraints(const IterationDomainContext &srcCtx,
// a0 -c0 (a1 - c1) (a1 - c2) = 0 // a0 -c0 (a1 - c1) (a1 - c2) = 0
// b0 -f0 (b1 - f1) (b1 - f2) = 0 // b0 -f0 (b1 - f1) (b1 - f2) = 0
// //
// Returns false if any AffineExpr cannot be flattened (which will be removed // Returns false if any AffineExpr cannot be flattened (due to it being
// when mod/floor/ceil support is added). Returns true otherwise. // semi-affine). Returns true otherwise.
static bool static bool
addMemRefAccessConstraints(const AffineValueMap &srcAccessMap, addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
const AffineValueMap &dstAccessMap, const AffineValueMap &dstAccessMap,
@ -827,48 +825,58 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
assert(srcMap.getNumResults() == dstMap.getNumResults()); assert(srcMap.getNumResults() == dstMap.getNumResults());
unsigned numResults = srcMap.getNumResults(); unsigned numResults = srcMap.getNumResults();
unsigned srcNumDims = srcMap.getNumDims();
unsigned srcNumSymbols = srcMap.getNumSymbols();
unsigned srcNumIds = srcNumDims + srcNumSymbols;
ArrayRef<MLValue *> srcOperands = srcAccessMap.getOperands(); ArrayRef<MLValue *> srcOperands = srcAccessMap.getOperands();
unsigned dstNumDims = dstMap.getNumDims();
unsigned dstNumSymbols = dstMap.getNumSymbols();
unsigned dstNumIds = dstNumDims + dstNumSymbols;
ArrayRef<MLValue *> dstOperands = dstAccessMap.getOperands(); ArrayRef<MLValue *> dstOperands = dstAccessMap.getOperands();
unsigned outputNumDims = dependenceDomain->getNumDimIds(); std::vector<SmallVector<int64_t, 8>> srcFlatExprs;
unsigned outputNumSymbols = dependenceDomain->getNumSymbolIds(); std::vector<SmallVector<int64_t, 8>> destFlatExprs;
unsigned outputNumIds = outputNumDims + outputNumSymbols; FlatAffineConstraints srcLocalVarCst, destLocalVarCst;
// Get flattened expressions for the source and destination maps.
if (!getFlattenedAffineExprs(srcMap, &srcFlatExprs, &srcLocalVarCst) ||
!getFlattenedAffineExprs(dstMap, &destFlatExprs, &destLocalVarCst))
return false;
SmallVector<int64_t, 4> eq(outputNumIds + 1); unsigned numLocalIdsToAdd =
SmallVector<int64_t, 4> flattenedExpr; srcLocalVarCst.getNumLocalIds() + destLocalVarCst.getNumLocalIds();
for (unsigned i = 0; i < numLocalIdsToAdd; i++) {
dependenceDomain->addLocalId(dependenceDomain->getNumLocalIds());
}
unsigned numDims = dependenceDomain->getNumDimIds();
unsigned numSymbols = dependenceDomain->getNumSymbolIds();
unsigned numSrcLocalIds = srcLocalVarCst.getNumLocalIds();
// Equality to add.
SmallVector<int64_t, 8> eq(dependenceDomain->getNumCols());
for (unsigned i = 0; i < numResults; ++i) { for (unsigned i = 0; i < numResults; ++i) {
// Zero fill. // Zero fill.
std::fill(eq.begin(), eq.end(), 0); std::fill(eq.begin(), eq.end(), 0);
// Get flattened AffineExpr for result 'i' from src access function.
auto srcExpr = srcMap.getResult(i);
flattenedExpr.clear();
if (!getFlattenedAffineExpr(srcExpr, srcNumDims, srcNumSymbols,
&flattenedExpr))
return false;
// Set identifier coefficients from src access function.
for (unsigned j = 0, e = srcOperands.size(); j < e; ++j)
eq[valuePosMap.getSrcDimOrSymPos(srcOperands[j])] = flattenedExpr[j];
// Set constant term.
eq[outputNumIds] = flattenedExpr[srcNumIds];
// Get flattened AffineExpr for result 'i' from dst access function. // Flattened AffineExpr for src result 'i'.
auto dstExpr = dstMap.getResult(i); const auto &srcFlatExpr = srcFlatExprs[i];
flattenedExpr.clear(); // Set identifier coefficients from src access function.
if (!getFlattenedAffineExpr(dstExpr, dstNumDims, dstNumSymbols, unsigned j, e;
&flattenedExpr)) for (j = 0, e = srcOperands.size(); j < e; ++j)
return false; eq[valuePosMap.getSrcDimOrSymPos(srcOperands[j])] = srcFlatExpr[j];
// Local terms.
for (e = srcFlatExpr.size() - 1; j < e; j++) {
eq[numDims + numSymbols + j] = srcFlatExpr[j];
}
// Set constant term.
eq[eq.size() - 1] = srcFlatExpr[srcFlatExpr.size() - 1];
// Flattened AffineExpr for dest result 'i'.
const auto &destFlatExpr = destFlatExprs[i];
// Set identifier coefficients from dst access function. // Set identifier coefficients from dst access function.
for (unsigned j = 0, e = dstOperands.size(); j < e; ++j) for (unsigned j = 0, e = dstOperands.size(); j < e; ++j)
eq[valuePosMap.getDstDimOrSymPos(dstOperands[j])] -= flattenedExpr[j]; eq[valuePosMap.getDstDimOrSymPos(dstOperands[j])] -= destFlatExpr[j];
// Local terms.
for (e = destFlatExpr.size() - 1; j < e; j++) {
eq[numDims + numSymbols + numSrcLocalIds + j] = destFlatExpr[j];
}
// Set constant term. // Set constant term.
eq[outputNumIds] -= flattenedExpr[dstNumIds]; eq[eq.size() - 1] -= destFlatExpr[destFlatExpr.size() - 1];
// Add equality constraint. // Add equality constraint.
dependenceDomain->addEquality(eq); dependenceDomain->addEquality(eq);
} }
@ -894,6 +902,9 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
addEqForConstOperands(srcOperands); addEqForConstOperands(srcOperands);
// Add equality constraints for any dst symbols defined by constant ops. // Add equality constraints for any dst symbols defined by constant ops.
addEqForConstOperands(dstOperands); addEqForConstOperands(dstOperands);
// TODO(bondhugula): add srcLocalVarCst, destLocalVarCst to the dependence
// domain.
return true; return true;
} }

View File

@ -518,13 +518,13 @@ FlatAffineConstraints::FlatAffineConstraints(IntegerSet set)
// Flatten expressions and add them to the constraint system. // Flatten expressions and add them to the constraint system.
std::vector<SmallVector<int64_t, 8>> flatExprs; std::vector<SmallVector<int64_t, 8>> flatExprs;
FlatAffineConstraints cst; FlatAffineConstraints localVarCst;
if (!getFlattenedAffineExprs(set, &flatExprs, &cst)) { if (!getFlattenedAffineExprs(set, &flatExprs, &localVarCst)) {
assert(false && "flattening unimplemented for semi-affine integer sets"); assert(false && "flattening unimplemented for semi-affine integer sets");
return; return;
} }
assert(flatExprs.size() == set.getNumConstraints()); assert(flatExprs.size() == set.getNumConstraints());
for (unsigned l = 0, e = cst.getNumLocalIds(); l < e; l++) { for (unsigned l = 0, e = localVarCst.getNumLocalIds(); l < e; l++) {
addLocalId(getNumLocalIds()); addLocalId(getNumLocalIds());
} }
@ -538,7 +538,7 @@ FlatAffineConstraints::FlatAffineConstraints(IntegerSet set)
} }
} }
// Add the other constraints involving local id's from flattening. // Add the other constraints involving local id's from flattening.
append(cst); append(localVarCst);
} }
void FlatAffineConstraints::reset(unsigned numReservedInequalities, void FlatAffineConstraints::reset(unsigned numReservedInequalities,
@ -1282,13 +1282,13 @@ bool FlatAffineConstraints::addBoundsFromForStmt(const ForStmt &forStmt) {
auto boundMap = auto boundMap =
lower ? forStmt.getLowerBoundMap() : forStmt.getUpperBoundMap(); lower ? forStmt.getLowerBoundMap() : forStmt.getUpperBoundMap();
FlatAffineConstraints cst; FlatAffineConstraints localVarCst;
std::vector<SmallVector<int64_t, 8>> flatExprs; std::vector<SmallVector<int64_t, 8>> flatExprs;
if (!getFlattenedAffineExprs(boundMap, &flatExprs, &cst)) { if (!getFlattenedAffineExprs(boundMap, &flatExprs, &localVarCst)) {
LLVM_DEBUG(llvm::dbgs() << "semi-affine expressions not yet supported\n"); LLVM_DEBUG(llvm::dbgs() << "semi-affine expressions not yet supported\n");
return false; return false;
} }
if (cst.getNumLocalIds() > 0) { if (localVarCst.getNumLocalIds() > 0) {
LLVM_DEBUG(llvm::dbgs() LLVM_DEBUG(llvm::dbgs()
<< "loop bounds with mod/floordiv expr's not yet supported\n"); << "loop bounds with mod/floordiv expr's not yet supported\n");
return false; return false;

View File

@ -111,9 +111,9 @@ mlfunc @store_load_different_symbols(%arg0 : index, %arg1 : index) {
mlfunc @store_load_diff_element_affine_apply_const() { mlfunc @store_load_diff_element_affine_apply_const() {
%m = alloc() : memref<100xf32> %m = alloc() : memref<100xf32>
%c1 = constant 1 : index %c1 = constant 1 : index
%c7 = constant 7.0 : f32 %c8 = constant 8.0 : f32
%a0 = affine_apply (d0) -> (d0) (%c1) %a0 = affine_apply (d0) -> (d0) (%c1)
store %c7, %m[%a0] : memref<100xf32> store %c8, %m[%a0] : memref<100xf32>
// expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}} // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}}
// expected-note@-2 {{dependence from 0 to 1 at depth 1 = false}} // expected-note@-2 {{dependence from 0 to 1 at depth 1 = false}}
%a1 = affine_apply (d0) -> (d0 + 1) (%c1) %a1 = affine_apply (d0) -> (d0 + 1) (%c1)
@ -565,3 +565,28 @@ mlfunc @war_raw_waw_deps() {
} }
return return
} }
// -----
// CHECK-LABEL: mlfunc @mod_deps() {
mlfunc @mod_deps() {
%m = alloc() : memref<100xf32>
%c7 = constant 7.0 : f32
for %i0 = 0 to 10 {
%a0 = affine_apply (d0) -> (d0 mod 2) (%i0)
// Results are conservative here since constraint information after
// flattening isn't being completely added. Will be done in the next CL.
// The third and the fifth dependence below shouldn't have existed.
%v0 = load %m[%a0] : memref<100xf32>
// expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}}
// expected-note@-2 {{dependence from 0 to 0 at depth 2 = false}}
// expected-note@-3 {{dependence from 0 to 1 at depth 1 = [1, 9]}}
// expected-note@-4 {{dependence from 0 to 1 at depth 2 = false}}
%a1 = affine_apply (d0) -> ( (d0 + 1) mod 2) (%i0)
store %c7, %m[%a1] : memref<100xf32>
// expected-note@-1 {{dependence from 1 to 0 at depth 1 = [1, 9]}}
// expected-note@-2 {{dependence from 1 to 0 at depth 2 = false}}
// expected-note@-3 {{dependence from 1 to 1 at depth 1 = [2, 9]}}
// expected-note@-4 {{dependence from 1 to 1 at depth 2 = false}}
}
return
}