Fix getFullMemRefAsRegion() and FlatAffineConstraints::reset

PiperOrigin-RevId: 231426734
This commit is contained in:
Uday Bondhugula 2019-01-29 10:24:30 -08:00 committed by jpienaar
parent c224a518f5
commit c0e9e5eb07
2 changed files with 22 additions and 12 deletions

View File

@ -248,17 +248,17 @@ void FlatAffineConstraints::reset(unsigned numReservedInequalities,
numDims = newNumDims;
numSymbols = newNumSymbols;
numIds = numDims + numSymbols + newNumLocals;
assert(idArgs.empty() || idArgs.size() == numIds);
clearConstraints();
if (numReservedEqualities >= 1)
equalities.reserve(newNumReservedCols * numReservedEqualities);
if (numReservedInequalities >= 1)
inequalities.reserve(newNumReservedCols * numReservedInequalities);
ids.clear();
if (idArgs.empty()) {
ids.resize(numIds, None);
} else {
ids.reserve(idArgs.size());
ids.append(idArgs.begin(), idArgs.end());
ids.assign(idArgs.begin(), idArgs.end());
}
}
@ -2078,6 +2078,10 @@ static BoundCmpResult compareBounds(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
// lower bounds and the max of the upper bounds along each of the dimensions.
bool FlatAffineConstraints::unionBoundingBox(
const FlatAffineConstraints &other) {
assert(other.getNumDimIds() == numDims);
assert(other.getNumSymbolIds() == getNumSymbolIds());
assert(other.getNumLocalIds() == 0);
assert(getNumLocalIds() == 0);
std::vector<SmallVector<int64_t, 8>> boundingLbs;
std::vector<SmallVector<int64_t, 8>> boundingUbs;
boundingLbs.reserve(2 * getNumDimIds());

View File

@ -147,8 +147,11 @@ static void getMultiLevelStrides(const MemRefRegion &region,
}
/// Construct the memref region to just include the entire memref. Returns false
/// dynamic shaped memref's for now.
static bool getFullMemRefAsRegion(OperationInst *opInst, unsigned numSymbols,
/// dynamic shaped memref's for now. `numParamLoopIVs` is the number of
/// enclosing loop IVs of opInst (starting from the outermost) that the region
/// is parametric on.
static bool getFullMemRefAsRegion(OperationInst *opInst,
unsigned numParamLoopIVs,
MemRefRegion *region) {
unsigned rank;
if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
@ -167,13 +170,16 @@ static bool getFullMemRefAsRegion(OperationInst *opInst, unsigned numSymbols,
if (memRefType.getNumDynamicDims() > 0)
return false;
SmallVector<ForInst *, 4> ivs;
getLoopIVs(*opInst, &ivs);
auto *regionCst = region->getConstraints();
SmallVector<Value *, 8> symbols = extractForInductionVars(ivs);
regionCst->reset(rank, numSymbols, 0, symbols);
// Just get the first numSymbols IVs, which the memref region is parametric
// on.
SmallVector<ForInst *, 4> ivs;
getLoopIVs(*opInst, &ivs);
ivs.resize(numParamLoopIVs);
SmallVector<Value *, 4> symbols = extractForInductionVars(ivs);
regionCst->reset(rank, numParamLoopIVs, 0);
regionCst->setIdValues(rank, rank + numParamLoopIVs, symbols);
// Memref dim sizes provide the bounds.
for (unsigned d = 0; d < rank; d++) {
@ -466,7 +472,7 @@ void DmaGeneration::runOnForInst(ForInst *forInst) {
// Perform a union with the existing region.
if (!(*it).second->unionBoundingBox(*region)) {
LLVM_DEBUG(llvm::dbgs()
<< "Memory region bounding box failed"
<< "Memory region bounding box failed; "
"over-approximating to the entire memref\n");
if (!getFullMemRefAsRegion(opInst, dmaDepth, region.get())) {
LLVM_DEBUG(forInst->emitError(
@ -479,7 +485,7 @@ void DmaGeneration::runOnForInst(ForInst *forInst) {
bool existsInRead = updateRegion(readRegions);
bool existsInWrite = updateRegion(writeRegions);
// Finally add it to the region.
// Finally add it to the region list.
if (region->isWrite() && !existsInWrite) {
writeRegions[region->memref] = std::move(region);
} else if (!region->isWrite() && !existsInRead) {