forked from OSchip/llvm-project
Fix getFullMemRefAsRegion() and FlatAffineConstraints::reset
PiperOrigin-RevId: 231426734
This commit is contained in:
parent
c224a518f5
commit
c0e9e5eb07
|
@ -248,17 +248,17 @@ void FlatAffineConstraints::reset(unsigned numReservedInequalities,
|
|||
numDims = newNumDims;
|
||||
numSymbols = newNumSymbols;
|
||||
numIds = numDims + numSymbols + newNumLocals;
|
||||
assert(idArgs.empty() || idArgs.size() == numIds);
|
||||
|
||||
clearConstraints();
|
||||
if (numReservedEqualities >= 1)
|
||||
equalities.reserve(newNumReservedCols * numReservedEqualities);
|
||||
if (numReservedInequalities >= 1)
|
||||
inequalities.reserve(newNumReservedCols * numReservedInequalities);
|
||||
ids.clear();
|
||||
if (idArgs.empty()) {
|
||||
ids.resize(numIds, None);
|
||||
} else {
|
||||
ids.reserve(idArgs.size());
|
||||
ids.append(idArgs.begin(), idArgs.end());
|
||||
ids.assign(idArgs.begin(), idArgs.end());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2078,6 +2078,10 @@ static BoundCmpResult compareBounds(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
|
|||
// lower bounds and the max of the upper bounds along each of the dimensions.
|
||||
bool FlatAffineConstraints::unionBoundingBox(
|
||||
const FlatAffineConstraints &other) {
|
||||
assert(other.getNumDimIds() == numDims);
|
||||
assert(other.getNumSymbolIds() == getNumSymbolIds());
|
||||
assert(other.getNumLocalIds() == 0);
|
||||
assert(getNumLocalIds() == 0);
|
||||
std::vector<SmallVector<int64_t, 8>> boundingLbs;
|
||||
std::vector<SmallVector<int64_t, 8>> boundingUbs;
|
||||
boundingLbs.reserve(2 * getNumDimIds());
|
||||
|
|
|
@ -147,8 +147,11 @@ static void getMultiLevelStrides(const MemRefRegion ®ion,
|
|||
}
|
||||
|
||||
/// Construct the memref region to just include the entire memref. Returns false
|
||||
/// dynamic shaped memref's for now.
|
||||
static bool getFullMemRefAsRegion(OperationInst *opInst, unsigned numSymbols,
|
||||
/// dynamic shaped memref's for now. `numParamLoopIVs` is the number of
|
||||
/// enclosing loop IVs of opInst (starting from the outermost) that the region
|
||||
/// is parametric on.
|
||||
static bool getFullMemRefAsRegion(OperationInst *opInst,
|
||||
unsigned numParamLoopIVs,
|
||||
MemRefRegion *region) {
|
||||
unsigned rank;
|
||||
if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
|
||||
|
@ -167,13 +170,16 @@ static bool getFullMemRefAsRegion(OperationInst *opInst, unsigned numSymbols,
|
|||
if (memRefType.getNumDynamicDims() > 0)
|
||||
return false;
|
||||
|
||||
SmallVector<ForInst *, 4> ivs;
|
||||
getLoopIVs(*opInst, &ivs);
|
||||
|
||||
auto *regionCst = region->getConstraints();
|
||||
|
||||
SmallVector<Value *, 8> symbols = extractForInductionVars(ivs);
|
||||
regionCst->reset(rank, numSymbols, 0, symbols);
|
||||
// Just get the first numSymbols IVs, which the memref region is parametric
|
||||
// on.
|
||||
SmallVector<ForInst *, 4> ivs;
|
||||
getLoopIVs(*opInst, &ivs);
|
||||
ivs.resize(numParamLoopIVs);
|
||||
SmallVector<Value *, 4> symbols = extractForInductionVars(ivs);
|
||||
regionCst->reset(rank, numParamLoopIVs, 0);
|
||||
regionCst->setIdValues(rank, rank + numParamLoopIVs, symbols);
|
||||
|
||||
// Memref dim sizes provide the bounds.
|
||||
for (unsigned d = 0; d < rank; d++) {
|
||||
|
@ -466,7 +472,7 @@ void DmaGeneration::runOnForInst(ForInst *forInst) {
|
|||
// Perform a union with the existing region.
|
||||
if (!(*it).second->unionBoundingBox(*region)) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "Memory region bounding box failed"
|
||||
<< "Memory region bounding box failed; "
|
||||
"over-approximating to the entire memref\n");
|
||||
if (!getFullMemRefAsRegion(opInst, dmaDepth, region.get())) {
|
||||
LLVM_DEBUG(forInst->emitError(
|
||||
|
@ -479,7 +485,7 @@ void DmaGeneration::runOnForInst(ForInst *forInst) {
|
|||
bool existsInRead = updateRegion(readRegions);
|
||||
bool existsInWrite = updateRegion(writeRegions);
|
||||
|
||||
// Finally add it to the region.
|
||||
// Finally add it to the region list.
|
||||
if (region->isWrite() && !existsInWrite) {
|
||||
writeRegions[region->memref] = std::move(region);
|
||||
} else if (!region->isWrite() && !existsInRead) {
|
||||
|
|
Loading…
Reference in New Issue