Update / complete a TODO for addBoundsForForStmt

- when adding constraints from a 'for' stmt into FlatAffineConstraints,
  correctly add bound operands of the 'for' stmt as a dimensional identifier or
  a symbolic identifier depending on whether the bound operand is a valid
  MLFunction symbol
- update test case to exercise this.

PiperOrigin-RevId: 225988511
This commit is contained in:
Uday Bondhugula 2018-12-18 06:43:20 -08:00 committed by jpienaar
parent 49c81ebcb0
commit 1d72f2e47e
4 changed files with 35 additions and 16 deletions

View File

@ -417,9 +417,10 @@ public:
/// right identifier is first looked up using forStmt's MLValue. Returns
/// false for the yet unimplemented/unsupported cases, and true if the
/// information is succesfully added. Asserts if the MLValue corresponding to
/// the 'for' statement isn't found in the system. Any new identifiers that
/// may need to be added due to the bound operands of the 'for' statement are
/// added as trailing dimensional identifiers (just before symbolic ones).
/// the 'for' statement isn't found in the constaint system. Any new
/// identifiers that are found in the bound operands of the 'for' statement
/// are added as trailing identifiers (either dimensional or symbolic
/// depending on whether the operand is a valid MLFunction symbol).
bool addBoundsFromForStmt(const ForStmt &forStmt);
/// Adds an upper bound expression for the specified expression.
@ -446,7 +447,7 @@ public:
// the kind of identifier. 'id' is the MLValue corresponding to the
// identifier that can optionally be provided.
void addDimId(unsigned pos, MLValue *id = nullptr);
void addSymbolId(unsigned pos);
void addSymbolId(unsigned pos, MLValue *id = nullptr);
void addLocalId(unsigned pos);
void addId(IdKind kind, unsigned pos, MLValue *id = nullptr);

View File

@ -601,8 +601,8 @@ void FlatAffineConstraints::addDimId(unsigned pos, MLValue *id) {
addId(IdKind::Dimension, pos, id);
}
void FlatAffineConstraints::addSymbolId(unsigned pos) {
addId(IdKind::Symbol, pos);
void FlatAffineConstraints::addSymbolId(unsigned pos, MLValue *id) {
addId(IdKind::Symbol, pos, id);
}
/// Adds a dimensional identifier. The added column is initialized to
@ -967,7 +967,7 @@ void FlatAffineConstraints::removeIdRange(unsigned idStart, unsigned idLimit) {
// Checks for emptiness of the set by eliminating identifiers successively and
// using the GCD test (on all equality constraints) and checking for trivially
// invalid constraints. Returns 'true' if the constaint system is found to be
// invalid constraints. Returns 'true' if the constraint system is found to be
// empty; false otherwise.
bool FlatAffineConstraints::isEmpty() const {
if (isEmptyByGCDTest() || hasInvalidConstraint())
@ -1264,18 +1264,32 @@ bool FlatAffineConstraints::addBoundsFromForStmt(const ForStmt &forStmt) {
auto addLowerOrUpperBound = [&](bool lower) -> bool {
auto operands = lower ? forStmt.getLowerBoundOperands()
: forStmt.getUpperBoundOperands();
SmallVector<unsigned, 8> positions;
for (const auto &operand : operands) {
unsigned loc;
// TODO(andydavis, bondhugula) AFFINE REFACTOR: merge with loop bounds
// code in dependence analysis.
if (!findId(*operand, &loc)) {
// Adding this as a dimensional identifier even if this operand was a
// symblic operand to the bound map. TODO(mlir-team): if needed, check
// which one it is and add as dimensional or a symbolic one.
addDimId(getNumDimIds(), const_cast<MLValue *>(operand));
loc = getNumDimIds() - 1;
if (operand->isValidSymbol()) {
addSymbolId(getNumSymbolIds(), const_cast<MLValue *>(operand));
loc = getNumDimIds() + getNumSymbolIds() - 1;
// Check if the symbol is a constant.
if (auto *opStmt = operand->getDefiningStmt()) {
if (auto constOp = opStmt->dyn_cast<ConstantIndexOp>()) {
setIdToConstant(*operand, constOp->getValue());
}
}
} else {
addDimId(getNumDimIds(), const_cast<MLValue *>(operand));
loc = getNumDimIds() - 1;
}
}
}
// Record positions of the operands in the constraint system.
SmallVector<unsigned, 8> positions;
for (const auto &operand : operands) {
unsigned loc;
if (!findId(*operand, &loc))
assert(0 && "expected to be found");
positions.push_back(loc);
}

View File

@ -124,8 +124,8 @@ Optional<int64_t> MemRefRegion::getBoundingConstantSizeAndShape(
/// Computes the memory region accessed by this memref with the region
/// represented as constraints symbolic/parameteric in 'loopDepth' loops
/// surrounding opStmt. Returns false if this fails due to yet unimplemented
/// cases.
/// surrounding opStmt and any additional MLFunction symbols. Returns false if
/// this fails due to yet unimplemented cases.
// For example, the memref region for this load operation at loopDepth = 1 will
// be as below:
//
@ -176,6 +176,9 @@ bool mlir::getMemRefRegion(OperationStmt *opStmt, unsigned loopDepth,
forwardSubstituteReachableOps(&accessValueMap);
AffineMap accessMap = accessValueMap.getAffineMap();
// We'll first associate the dims and symbols of the access map to the dims
// and symbols resp. of regionCst. This will change below once regionCst is
// fully constructed out.
regionCst->reset(accessMap.getNumDims(), accessMap.getNumSymbols(), 0,
accessValueMap.getOperands());

View File

@ -190,13 +190,14 @@ mlfunc @loop_nest_tiled() -> memref<256x1024xf32> {
// CHECK-LABEL: mlfunc @dma_constant_dim_access
mlfunc @dma_constant_dim_access(%A : memref<100x100xf32>) {
%one = constant 1 : index
%N = constant 100 : index
// CHECK: %0 = alloc() : memref<1x100xf32, 1>
// CHECK-NEXT: %1 = alloc() : memref<1xi32>
// No strided DMA needed here.
// CHECK: dma_start %arg0[%c1, %c0], %0[%c0, %c0], %c100, %1[%c0] : memref<100x100xf32>, memref<1x100xf32, 1>,
// CHECK-NEXT: dma_wait %1[%c0], %c100 : memref<1xi32>
for %i = 0 to 100 {
for %j = 0 to 100 {
for %j = 0 to ()[s0] -> (s0) ()[%N] {
// CHECK: %2 = affine_apply [[MAP_MINUS_ONE]](%c1, %i1)
// CHECK-NEXT: %3 = load %0[%2#0, %2#1] : memref<1x100xf32, 1>
load %A[%one, %j] : memref<100 x 100 x f32>