Extend/complete dependence tester to utilize local var info.

- extend/complete dependence tester to utilize local var info while adding
  access function equality constraints; one more step closer to get slicing
  based fusion working in the general case of affine_apply's involving mod's/div's.
- update test case to reflect more accurate dependence information; remove
  inaccurate comment on test case mod_deps.
- fix a minor "bug" in equality addition in addMemRefAccessConstraints (doesn't
  affect correctness, but the fixed version is more intuitive).
- some more surrounding code clean up
- move simplifyAffineExpr out of anonymous AffineExprFlattener class - the
  latter has state, and the former should reside outside.

PiperOrigin-RevId: 227175600
This commit is contained in:
Uday Bondhugula 2018-12-28 15:34:07 -08:00 committed by jpienaar
parent 315a466aed
commit b1d9cc4d1e
3 changed files with 73 additions and 31 deletions

View File

@ -247,22 +247,6 @@ public:
eq[getConstantIndex()] = expr.getValue();
}
// Simplify the affine expression by flattening it and reconstructing it.
AffineExpr simplifyAffineExpr(AffineExpr expr) {
// TODO(bondhugula): only pure affine for now. The simplification here can
// be extended to semi-affine maps in the future.
if (!expr.isPureAffine())
return expr;
walkPostOrder(expr);
ArrayRef<int64_t> flattenedExpr = operandExprStack.back();
auto simplifiedExpr = toAffineExpr(flattenedExpr, numDims, numSymbols,
localExprs, expr.getContext());
operandExprStack.pop_back();
assert(operandExprStack.empty());
return simplifiedExpr;
}
private:
void visitDivExpr(AffineBinaryOpExpr expr, bool isCeil) {
assert(operandExprStack.size() >= 2);
@ -356,10 +340,23 @@ private:
} // end anonymous namespace
/// Simplify the affine expression by flattening it and reconstructing it.
AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
unsigned numSymbols) {
// TODO(bondhugula): only pure affine for now. The simplification here can
// be extended to semi-affine maps in the future.
if (!expr.isPureAffine())
return expr;
AffineExprFlattener flattener(numDims, numSymbols, expr.getContext());
return flattener.simplifyAffineExpr(expr);
flattener.walkPostOrder(expr);
ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
auto simplifiedExpr = toAffineExpr(flattenedExpr, numDims, numSymbols,
flattener.localExprs, expr.getContext());
flattener.operandExprStack.pop_back();
assert(flattener.operandExprStack.empty());
return simplifiedExpr;
}
/// Returns the AffineExpr that results from substituting `exprs[i]` into `e`
@ -416,6 +413,7 @@ static bool getFlattenedAffineExprs(
return true;
}
flattenedExprs->clear();
flattenedExprs->reserve(exprs.size());
AffineExprFlattener flattener(numDims, numSymbols, exprs[0].getContext());
@ -428,6 +426,7 @@ static bool getFlattenedAffineExprs(
flattener.walkPostOrder(expr);
}
assert(flattener.operandExprStack.size() == exprs.size());
flattenedExprs->insert(flattenedExprs->end(),
flattener.operandExprStack.begin(),
flattener.operandExprStack.end());
@ -766,11 +765,15 @@ static void addDomainConstraints(const FlatAffineConstraints &srcDomain,
//
// Returns false if any AffineExpr cannot be flattened (due to it being
// semi-affine). Returns true otherwise.
// TODO(bondhugula): assumes that dependenceDomain doesn't have local
// variables already. Fix this soon.
static bool
addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
const AffineValueMap &dstAccessMap,
const ValuePositionMap &valuePosMap,
FlatAffineConstraints *dependenceDomain) {
if (dependenceDomain->getNumLocalIds() != 0)
return false;
AffineMap srcMap = srcAccessMap.getAffineMap();
AffineMap dstMap = dstAccessMap.getAffineMap();
assert(srcMap.getNumResults() == dstMap.getNumResults());
@ -826,7 +829,7 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
// Local terms.
for (unsigned j = 0, e = dstNumLocalIds; j < e; j++)
eq[numDims + numSymbols + numSrcLocalIds + j] =
destFlatExpr[dstNumIds + j];
-destFlatExpr[dstNumIds + j];
// Set constant term.
eq[eq.size() - 1] -= destFlatExpr[destFlatExpr.size() - 1];
@ -856,8 +859,45 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
// Add equality constraints for any dst symbols defined by constant ops.
addEqForConstOperands(dstOperands);
// TODO(b/122081337): add srcLocalVarCst, destLocalVarCst to the dependence
// domain.
// By construction (see flattener), local var constraints will not have any
// equalities.
assert(srcLocalVarCst.getNumEqualities() == 0 &&
destLocalVarCst.getNumEqualities() == 0);
// Add inequalities from srcLocalVarCst and destLocalVarCst into the
// dependence domain.
SmallVector<int64_t, 8> ineq(dependenceDomain->getNumCols());
for (unsigned r = 0, e = srcLocalVarCst.getNumInequalities(); r < e; r++) {
std::fill(ineq.begin(), ineq.end(), 0);
// Set identifier coefficients from src local var constraints.
for (unsigned j = 0, e = srcOperands.size(); j < e; ++j)
ineq[valuePosMap.getSrcDimOrSymPos(srcOperands[j])] =
srcLocalVarCst.atIneq(r, j);
// Local terms.
for (unsigned j = 0, e = srcNumLocalIds; j < e; j++)
ineq[numDims + numSymbols + j] = srcLocalVarCst.atIneq(r, srcNumIds + j);
// Set constant term.
ineq[ineq.size() - 1] =
srcLocalVarCst.atIneq(r, srcLocalVarCst.getNumCols() - 1);
dependenceDomain->addInequality(ineq);
}
for (unsigned r = 0, e = destLocalVarCst.getNumInequalities(); r < e; r++) {
std::fill(ineq.begin(), ineq.end(), 0);
// Set identifier coefficients from dest local var constraints.
for (unsigned j = 0, e = dstOperands.size(); j < e; ++j)
ineq[valuePosMap.getDstDimOrSymPos(dstOperands[j])] =
destLocalVarCst.atIneq(r, j);
// Local terms.
for (unsigned j = 0, e = dstNumLocalIds; j < e; j++)
ineq[numDims + numSymbols + numSrcLocalIds + j] =
destLocalVarCst.atIneq(r, dstNumIds + j);
// Set constant term.
ineq[ineq.size() - 1] =
destLocalVarCst.atIneq(r, destLocalVarCst.getNumCols() - 1);
dependenceDomain->addInequality(ineq);
}
return true;
}

View File

@ -1102,13 +1102,15 @@ AffineMap FlatAffineConstraints::toAffineMapFromEq(
unsigned idx, unsigned pos, MLIRContext *context,
SmallVectorImpl<unsigned> *nonZeroDimIds,
SmallVectorImpl<unsigned> *nonZeroSymbolIds) {
assert(getNumLocalIds() == 0);
assert(idx < getNumEqualities());
assert(getNumLocalIds() == 0 && "local ids not supported");
assert(idx < getNumEqualities() && "invalid equality position");
int64_t v = atEq(idx, pos);
// Return if coefficient at (idx, pos) is zero or does not divide constant.
if (v == 0 || (atEq(idx, getNumIds()) % v != 0))
return AffineMap::Null();
// Check that coefficient at 'pos' divides all other coefficient in row 'idx'.
// Check that coefficient at 'pos' divides all other coefficients in row
// 'idx'.
for (unsigned j = 0, e = getNumIds(); j < e; ++j) {
if (j != pos && (atEq(idx, j) % v != 0))
return AffineMap::Null();
@ -1441,7 +1443,7 @@ void FlatAffineConstraints::constantFoldIdRange(unsigned pos, unsigned num) {
// i + s0 + 16 <= d0 <= i + s0 + 31, returns 16.
Optional<int64_t> FlatAffineConstraints::getConstantBoundOnDimSize(
unsigned pos, SmallVectorImpl<int64_t> *lb) const {
assert(pos < getNumDimIds() && "Invalid position");
assert(pos < getNumDimIds() && "Invalid identifier position");
assert(getNumLocalIds() == 0);
// TODO(bondhugula): eliminate all remaining dimensional identifiers (other

View File

@ -625,9 +625,8 @@ mlfunc @mod_deps() {
%c7 = constant 7.0 : f32
for %i0 = 0 to 10 {
%a0 = affine_apply (d0) -> (d0 mod 2) (%i0)
// Results are conservative here since constraint information after
// flattening isn't being completely added. Will be done in the next CL.
// The third and the fifth dependence below shouldn't have existed.
// Results are conservative here since we currently don't have a way to
// represent strided sets in FlatAffineConstraints.
%v0 = load %m[%a0] : memref<100xf32>
// expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}}
// expected-note@-2 {{dependence from 0 to 0 at depth 2 = false}}
@ -688,14 +687,15 @@ mlfunc @mod_div_3d() {
for %i1 = 0 to 8 {
for %i2 = 0 to 8 {
%idx = affine_apply (d0, d1, d2) -> (d0 floordiv 4, d1 mod 2, d2 floordiv 4) (%i0, %i1, %i2)
// Dependences below are conservative due to TODO(b/122081337).
store %c0, %M[%idx#0, %idx#1, %idx#2] : memref<2 x 2 x 2 x i32>
// expected-note@-1 {{dependence from 0 to 0 at depth 1 = [1, 7][-7, 7][-7, 7]}}
// expected-note@-2 {{dependence from 0 to 0 at depth 2 = [0, 0][2, 7][-7, 7]}}
// expected-note@-3 {{dependence from 0 to 0 at depth 3 = [0, 0][0, 0][1, 7]}}
// expected-note@-1 {{dependence from 0 to 0 at depth 1 = [1, 3][-7, 7][-3, 3]}}
// expected-note@-2 {{dependence from 0 to 0 at depth 2 = [0, 0][2, 7][-3, 3]}}
// expected-note@-3 {{dependence from 0 to 0 at depth 3 = [0, 0][0, 0][1, 3]}}
// expected-note@-4 {{dependence from 0 to 0 at depth 4 = false}}
}
}
}
return
}
// TODO(bondhugula): add more test cases exercising mod/div affine_apply's.