Updates to transformation/analysis passes/utilities. Update DMA generation pass
and getMemRefRegion() to work with specified loop depths; add support for
outgoing DMAs and store op's.

- add support for getMemRefRegion() to compute regions symbolic in outer loops,
  and hence support for DMAs symbolic in the outer surrounding loops (see the
  sketch below).

- add DMA generation support for outgoing DMAs (store op's to the lower memory
  space); extend getMemRefRegion() to store op's. -memref-bound-check now works
  with store op's as well.

- fix dma-generate (references to the old memref in the dma_start op were also
  being replaced with the new buffer); we need replaceAllMemRefUsesWith to work
  on only a subset of the uses - add an optional 'domStmtFilter' statement
  argument to replaceAllMemRefUsesWith to serve as a filter: if provided, only
  those uses that are dominated by the filter are replaced.

- Add missing printing of attributes for the dma_start and dma_wait op's.

- update the FlatAffineConstraints API: identifiers (columns) can now carry
  their corresponding MLValue, and loop bounds can be added directly from a
  ForStmt.

PiperOrigin-RevId: 221889223
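
The first three bullets above boil down to the usage pattern sketched below. This is a minimal editorial sketch, not code from this commit: inspectAccess is a hypothetical helper, and opStmt is assumed to be a load or store OperationStmt nested inside a loop nest.

#include "mlir/Analysis/Utils.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Compute the region of the memref accessed by 'opStmt', symbolic in the
// 'loopDepth' outermost surrounding loops, and query its constant footprint.
static void inspectAccess(OperationStmt *opStmt, unsigned loopDepth) {
  MemRefRegion region;
  if (!getMemRefRegion(opStmt, loopDepth, &region))
    return; // Yet unimplemented/unsupported case.

  // Constant extents per memref dimension, if known (e.g. {8} for the
  // example in the MemRefRegion documentation further below).
  llvm::SmallVector<int, 4> shape;
  if (region.getConstantShape(&shape)) {
    // 'shape' sizes the fast-memory buffer for DMA generation.
  }

  // Total number of elements, if constant; DMA generation compares this
  // against the minimum profitable transfer size.
  if (auto numElements = region.getConstantSize())
    (void)numElements;

  // A write region corresponds to an outgoing DMA placed after the loop.
  (void)region.isWrite();
}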
Uday Bondhugula 2018-11-16 20:12:06 -08:00 committed by jpienaar
parent 6b52ac3aa6
commit fff1efbaf5
14 changed files with 882 additions and 229 deletions

View File

@ -34,6 +34,7 @@ class AffineApplyOp;
class AffineBound; class AffineBound;
class AffineCondition; class AffineCondition;
class AffineMap; class AffineMap;
class ForStmt;
class IntegerSet; class IntegerSet;
class MLIRContext; class MLIRContext;
class MLValue; class MLValue;
@ -177,7 +178,6 @@ public:
ArrayRef<MLValue *> getOperands() const; ArrayRef<MLValue *> getOperands() const;
AffineMap getAffineMap() const; AffineMap getAffineMap() const;
private: private:
void forwardSubstitute(const AffineApplyOp &inputOp, void forwardSubstitute(const AffineApplyOp &inputOp,
ArrayRef<bool> inputResultsToSubstitute); ArrayRef<bool> inputResultsToSubstitute);
@ -244,13 +244,19 @@ public:
FlatAffineConstraints(unsigned numReservedInequalities, FlatAffineConstraints(unsigned numReservedInequalities,
unsigned numReservedEqualities, unsigned numReservedEqualities,
unsigned numReservedCols, unsigned numDims = 0, unsigned numReservedCols, unsigned numDims = 0,
unsigned numSymbols = 0, unsigned numLocals = 0) unsigned numSymbols = 0, unsigned numLocals = 0,
ArrayRef<Optional<MLValue *>> idArgs = {})
: numReservedCols(numReservedCols), numDims(numDims), : numReservedCols(numReservedCols), numDims(numDims),
numSymbols(numSymbols) { numSymbols(numSymbols) {
assert(numReservedCols >= numDims + numSymbols + 1); assert(numReservedCols >= numDims + numSymbols + 1);
equalities.reserve(numReservedCols * numReservedEqualities); equalities.reserve(numReservedCols * numReservedEqualities);
inequalities.reserve(numReservedCols * numReservedInequalities); inequalities.reserve(numReservedCols * numReservedInequalities);
numIds = numDims + numSymbols + numLocals; numIds = numDims + numSymbols + numLocals;
ids.reserve(numReservedCols);
if (idArgs.empty())
ids.resize(numIds, None);
else
ids.insert(ids.end(), idArgs.begin(), idArgs.end());
} }
/// Constructs a constraint system with the specified number of /// Constructs a constraint system with the specified number of
@ -261,6 +267,7 @@ public:
numSymbols(numSymbols) { numSymbols(numSymbols) {
assert(numReservedCols >= numDims + numSymbols + 1); assert(numReservedCols >= numDims + numSymbols + 1);
numIds = numDims + numSymbols + numLocals; numIds = numDims + numSymbols + numLocals;
ids.resize(numIds, None);
} }
explicit FlatAffineConstraints(const HyperRectangularSet &set); explicit FlatAffineConstraints(const HyperRectangularSet &set);
@ -290,10 +297,10 @@ public:
// Clears any existing data and reserves memory for the specified constraints. // Clears any existing data and reserves memory for the specified constraints.
void reset(unsigned numReservedInequalities, unsigned numReservedEqualities, void reset(unsigned numReservedInequalities, unsigned numReservedEqualities,
unsigned numReservedCols, unsigned numDims, unsigned numSymbols, unsigned numReservedCols, unsigned numDims, unsigned numSymbols,
unsigned numLocals = 0); unsigned numLocals = 0, ArrayRef<MLValue *> idArgs = {});
void reset(unsigned numDims = 0, unsigned numSymbols = 0, void reset(unsigned numDims = 0, unsigned numSymbols = 0,
unsigned numLocals = 0); unsigned numLocals = 0, ArrayRef<MLValue *> idArgs = {});
/// Appends constraints from 'other' into this. This is equivalent to an /// Appends constraints from 'other' into this. This is equivalent to an
/// intersection with no simplification of any sort attempted. /// intersection with no simplification of any sort attempted.
@ -396,6 +403,12 @@ public:
/// Adds a lower bound expression for the specified expression. /// Adds a lower bound expression for the specified expression.
void addLowerBound(ArrayRef<int64_t> expr, ArrayRef<int64_t> lb); void addLowerBound(ArrayRef<int64_t> expr, ArrayRef<int64_t> lb);
/// Adds constraints (lower and upper bounds) from the ForStmt into the
/// FlatAffineConstraints. 'forStmt's' MLValue is used to look up the right
/// identifier, and if it doesn't exist, a new one is added. Returns false for
/// the yet unimplemented/unsupported cases.
bool addBoundsFromForStmt(unsigned pos, ForStmt *forStmt);
/// Adds an upper bound expression for the specified expression. /// Adds an upper bound expression for the specified expression.
void addUpperBound(ArrayRef<int64_t> expr, ArrayRef<int64_t> ub); void addUpperBound(ArrayRef<int64_t> expr, ArrayRef<int64_t> ub);
@ -407,12 +420,17 @@ public:
/// Sets the identifier at the specified position to a constant. /// Sets the identifier at the specified position to a constant.
void setIdToConstant(unsigned pos, int64_t val); void setIdToConstant(unsigned pos, int64_t val);
/// Looks up the identifier with the specified MLValue. Returns false if not
/// found.
bool findId(const MLValue &operand, unsigned *pos);
// Add identifiers of the specified kind - specified positions are relative to // Add identifiers of the specified kind - specified positions are relative to
// the kind of identifier. // the kind of identifier. 'id' is the MLValue corresponding to the
void addDimId(unsigned pos); // identifier that can optionally be provided.
void addDimId(unsigned pos, MLValue *id = nullptr);
void addSymbolId(unsigned pos); void addSymbolId(unsigned pos);
void addLocalId(unsigned pos); void addLocalId(unsigned pos);
void addId(IdKind kind, unsigned pos); void addId(IdKind kind, unsigned pos, MLValue *id = nullptr);
/// Composes the affine value map with this FlatAffineConstrains, adding the /// Composes the affine value map with this FlatAffineConstrains, adding the
/// results of the map as dimensions at the specified position and with the /// results of the map as dimensions at the specified position and with the
@ -435,6 +453,9 @@ public:
// value to mark exactness for example. // value to mark exactness for example.
void projectOut(unsigned pos, unsigned num); void projectOut(unsigned pos, unsigned num);
/// Projects out the identifier that is associate with MLValue *.
void projectOut(MLValue *id);
void removeId(IdKind idKind, unsigned pos); void removeId(IdKind idKind, unsigned pos);
void removeId(unsigned pos); void removeId(unsigned pos);
@ -453,19 +474,30 @@ public:
return numIds - numDims - numSymbols; return numIds - numDims - numSymbols;
} }
inline ArrayRef<Optional<MLValue *>> getIds() const {
return {ids.data(), ids.size()};
}
/// Clears this list of constraints and copies other into it. /// Clears this list of constraints and copies other into it.
void clearAndCopyFrom(const FlatAffineConstraints &other); void clearAndCopyFrom(const FlatAffineConstraints &other);
/// Returns the constant lower bound of the specified identifier (through a /// Returns the constant lower bound of the specified identifier (through a
/// scan through the constraints); returns None if the bound isn't trivially a /// scan through the constraints); returns None if the bound isn't trivially a
/// constant. /// constant.
Optional<int64_t> getConstantLowerBound(unsigned pos); Optional<int64_t> getConstantLowerBound(unsigned pos) const;
/// Returns the constant upper bound of the specified identifier (through a /// Returns the constant upper bound of the specified identifier (through a
/// scan through the constraints); returns None if the bound isn't trivially a /// scan through the constraints); returns None if the bound isn't trivially a
/// constant. Note that the upper bound for FlatAffineConstraints is /// constant. Note that the upper bound for FlatAffineConstraints is
/// inclusive. /// inclusive.
Optional<int64_t> getConstantUpperBound(unsigned pos); Optional<int64_t> getConstantUpperBound(unsigned pos) const;
/// Returns the extent (upper bound - lower bound) of the specified
/// identifier if it is found to be a constant; returns None if it's not a
/// constant. 'lbPosition' is set to the row position of the corresponding
/// lower bound.
Optional<int64_t> getConstantBoundDifference(unsigned pos,
unsigned *lbPosition) const;
// Returns the lower and upper bounds of the specified dimensions as // Returns the lower and upper bounds of the specified dimensions as
// AffineMap's. Returns false for the unimplemented cases for the moment. // AffineMap's. Returns false for the unimplemented cases for the moment.
@ -509,6 +541,12 @@ private:
/// Number of identifiers corresponding to symbols (unknown but constant for /// Number of identifiers corresponding to symbols (unknown but constant for
/// analysis). /// analysis).
unsigned numSymbols; unsigned numSymbols;
/// MLValues corresponding to the (column) identifiers of this constraint
/// system appearing in the order the identifiers correspond to columns.
/// Temporary ones or those that aren't associated to any MLValue are to be
/// set to None.
SmallVector<Optional<MLValue *>, 8> ids;
}; };
} // end namespace mlir. } // end namespace mlir.
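
The new identifier-aware methods declared above (addBoundsFromForStmt, findId, getConstantBoundDifference, and projectOut by MLValue) compose roughly as in the following editorial sketch. It assumes 'cst', 'forStmt', and 'pos' are set up by the caller, with column 'pos' corresponding to forStmt's induction variable, and it relies on the assumption (based on how the rest of this commit uses these methods) that a ForStmt also acts as the MLValue of its induction variable.

#include "mlir/Analysis/AffineStructures.h"
#include "mlir/IR/Statements.h"

using namespace mlir;

static void boundsSketch(FlatAffineConstraints &cst, ForStmt *forStmt,
                         unsigned pos) {
  // Add the loop's lower/upper bounds as inequalities on column 'pos';
  // operands of non-constant bounds get identifiers of their own.
  if (!cst.addBoundsFromForStmt(pos, forStmt))
    return; // Semi-affine or mod/floordiv bounds: not yet supported.

  // Constant extent (ub - lb) of the IV's range, if known; for
  // 'for %i = 0 to 32' this yields 32.
  unsigned lbRow;
  if (auto extent = cst.getConstantBoundDifference(pos, &lbRow))
    (void)extent;

  // Identifiers can now be looked up and eliminated by their MLValue.
  MLValue *iv = forStmt; // Assumption: ForStmt doubles as the IV's MLValue.
  unsigned loc;
  if (cst.findId(*iv, &loc))
    cst.projectOut(iv);
}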

View File

@ -25,9 +25,15 @@
#ifndef MLIR_ANALYSIS_UTILS_H #ifndef MLIR_ANALYSIS_UTILS_H
#define MLIR_ANALYSIS_UTILS_H #define MLIR_ANALYSIS_UTILS_H
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallVector.h"
#include <memory>
namespace mlir { namespace mlir {
class FlatAffineConstraints; class FlatAffineConstraints;
class MLValue;
class OperationStmt; class OperationStmt;
class Statement; class Statement;
@ -37,8 +43,69 @@ bool dominates(const Statement &a, const Statement &b);
/// Returns true if statement 'a' properly dominates statement b. /// Returns true if statement 'a' properly dominates statement b.
bool properlyDominates(const Statement &a, const Statement &b); bool properlyDominates(const Statement &a, const Statement &b);
/// Returns the memory region accessed by this memref. /// A region of a memref's data space; this is typically constructed by
bool getMemoryRegion(OperationStmt *opStmt, FlatAffineConstraints *region); /// analyzing load/store op's on this memref and the index space of loops
/// surrounding such op's.
// For example, the memref region for a load operation at loop depth = 1:
//
// for %i = 0 to 32 {
// for %ii = %i to (d0) -> (d0 + 8) (%i) {
// load %A[%ii]
// }
// }
//
// Region: {memref = %A, write = false, {%i <= m0 <= %i + 7} }
// The last field is a 2-d FlatAffineConstraints symbolic in %i.
//
struct MemRefRegion {
FlatAffineConstraints *getConstraints() { return &cst; }
const FlatAffineConstraints *getConstraints() const { return &cst; }
bool isWrite() const { return write; }
void setWrite(bool flag) { write = flag; }
// Computes the shape if the extents are known constants, returns false
// otherwise.
bool getConstantShape(llvm::SmallVectorImpl<int> *shape) const;
// Returns the number of elements in this region if it's a known constant. We
// use int64_t instead of uint64_t since index types can be at most int64_t.
Optional<int64_t> getConstantSize() const;
/// Memref that this region corresponds to.
MLValue *memref;
private:
/// Read or write.
bool write;
/// Region (data space) of the memref accessed. This set will thus have at
/// least as many dimensional identifiers as the shape dimensionality of the
/// memref, and these are the leading dimensions of the set appearing in that
/// order (major to minor / outermost to innermost). There may be additional
/// identifiers since getMemRefRegion() is called with a specific loop depth,
/// and thus the region is symbolic in the outer surrounding loops at that
/// depth.
// TODO(bondhugula): Replace this to exploit HyperRectangularSet.
FlatAffineConstraints cst;
};
/// Computes the memory region accessed by this memref with the region
/// represented as constraints symbolic/parameteric in 'loopDepth' loops
/// surrounding opStmt. Returns false if this fails due to yet unimplemented
/// cases.
// For example, the memref region for this operation at loopDepth = 1 will be:
//
// for %i = 0 to 32 {
// for %ii = %i to (d0) -> (d0 + 8) (%i) {
// load %A[%ii]
// }
// }
//
// {memref = %A, write = false, {%i <= m0 <= %i + 7} }
// The last field is a 2-d FlatAffineConstraints symbolic in %i.
//
bool getMemRefRegion(OperationStmt *opStmt, unsigned loopDepth,
MemRefRegion *region);
} // end namespace mlir } // end namespace mlir
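
To make the documentation example above concrete: for the region {%i <= m0 <= %i + 7}, the constraint rows and the constant footprint derived from them would look roughly as follows (editorial illustration using the canonical >= 0 form; not code from the commit).

// Region of %A for 'load %A[%ii]' at loopDepth = 1, symbolic in %i
// (m0 = data dimension of %A, %i = outer-loop identifier):
//
//    m0 - %i     >= 0      // lower bound row: m0 >= %i
//   -m0 + %i + 7 >= 0      // upper bound row: m0 <= %i + 7
//
// The two rows differ only in their constant term, so
// getConstantBoundDifference(0, &lbRow) = 7 + 0 + 1 = 8, and hence
// getConstantShape() gives {8} and getConstantSize() gives 8: an
// 8-element fast buffer suffices regardless of the value of %i.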

View File

@ -43,15 +43,17 @@ class SSAValue;
/// Replace all uses of oldMemRef with newMemRef while optionally remapping the /// Replace all uses of oldMemRef with newMemRef while optionally remapping the
/// old memref's indices using the supplied affine map and adding any additional /// old memref's indices using the supplied affine map and adding any additional
/// indices. The new memref could be of a different shape or rank. Returns true /// indices. The new memref could be of a different shape or rank. An optional
/// on success and false if the replacement is not possible (whenever a memref /// argument 'domOpFilter' restricts the replacement to only those operations
/// is used as an operand in a non-deferencing scenario). /// that are dominated by the former. Returns true on success and false if the
/// Additional indices are added at the start. /// replacement is not possible (whenever a memref is used as an operand in a
/// non-deferencing scenario). Additional indices are added at the start.
// TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily // TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily
// extended to add additional indices at any position. // extended to add additional indices at any position.
bool replaceAllMemRefUsesWith(const MLValue *oldMemRef, MLValue *newMemRef, bool replaceAllMemRefUsesWith(const MLValue *oldMemRef, MLValue *newMemRef,
llvm::ArrayRef<MLValue *> extraIndices = {}, llvm::ArrayRef<MLValue *> extraIndices = {},
AffineMap indexRemap = AffineMap::Null()); AffineMap indexRemap = AffineMap::Null(),
const Statement *domStmtFilter = nullptr);
/// Creates and inserts into 'builder' a new AffineApplyOp, with the number of /// Creates and inserts into 'builder' a new AffineApplyOp, with the number of
/// its results equal to the number of operands, as a composition /// its results equal to the number of operands, as a composition
@ -64,7 +66,7 @@ OperationStmt *
createComposedAffineApplyOp(FuncBuilder *builder, Location loc, createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
ArrayRef<MLValue *> operands, ArrayRef<MLValue *> operands,
ArrayRef<OperationStmt *> affineApplyOps, ArrayRef<OperationStmt *> affineApplyOps,
SmallVectorImpl<SSAValue *> &results); SmallVectorImpl<SSAValue *> *results);
/// Given an operation statement, inserts a new single affine apply operation, /// Given an operation statement, inserts a new single affine apply operation,
/// that is exclusively used by this operation statement, and that provides all /// that is exclusively used by this operation statement, and that provides all
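
With the new optional argument above, restricting a rewrite to the body of a given loop looks roughly like this (editorial sketch; the helper name remapWithinLoop is hypothetical, and the header path is assumed to be mlir/Transforms/Utils.h).

#include "mlir/Transforms/Utils.h"

using namespace mlir;

// Rewrites only those uses of 'oldMemRef' that are dominated by 'forStmt'
// to use 'fastBuf'; this is what keeps the dma_start op's own reference to
// the original memref from being rewritten.
static bool remapWithinLoop(MLValue *oldMemRef, MLValue *fastBuf,
                            const Statement *forStmt) {
  return replaceAllMemRefUsesWith(oldMemRef, fastBuf,
                                  /*extraIndices=*/{},
                                  /*indexRemap=*/AffineMap::Null(),
                                  /*domStmtFilter=*/forStmt);
}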

View File

@ -897,7 +897,7 @@ static void computeDirectionVector(
dependenceDomain->addDimId(j); dependenceDomain->addDimId(j);
} }
// Add equality contraints for each common loop, setting newly instroduced // Add equality contraints for each common loop, setting newly introduced
// variable at column 'j' to the 'dst' IV minus the 'src IV. // variable at column 'j' to the 'dst' IV minus the 'src IV.
SmallVector<int64_t, 4> eq; SmallVector<int64_t, 4> eq;
eq.resize(dependenceDomain->getNumCols()); eq.resize(dependenceDomain->getNumCols());

View File

@ -26,6 +26,7 @@
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IntegerSet.h" #include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLValue.h" #include "mlir/IR/MLValue.h"
#include "mlir/IR/Statements.h"
#include "mlir/Support/MathExtras.h" #include "mlir/Support/MathExtras.h"
#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/DenseSet.h"
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
@ -480,6 +481,10 @@ FlatAffineConstraints::FlatAffineConstraints(
numSymbols = other.getNumSymbolIds(); numSymbols = other.getNumSymbolIds();
numIds = other.getNumIds(); numIds = other.getNumIds();
auto otherIds = other.getIds();
ids.reserve(numReservedCols);
ids.insert(ids.end(), otherIds.begin(), otherIds.end());
unsigned numReservedEqualities = other.getNumReservedEqualities(); unsigned numReservedEqualities = other.getNumReservedEqualities();
unsigned numReservedInequalities = other.getNumReservedInequalities(); unsigned numReservedInequalities = other.getNumReservedInequalities();
@ -506,6 +511,7 @@ FlatAffineConstraints::FlatAffineConstraints(IntegerSet set)
numSymbols(set.getNumSymbols()) { numSymbols(set.getNumSymbols()) {
equalities.reserve(set.getNumEqualities() * numReservedCols); equalities.reserve(set.getNumEqualities() * numReservedCols);
inequalities.reserve(set.getNumInequalities() * numReservedCols); inequalities.reserve(set.getNumInequalities() * numReservedCols);
ids.resize(numIds, None);
for (unsigned i = 0, e = set.getNumConstraints(); i < e; ++i) { for (unsigned i = 0, e = set.getNumConstraints(); i < e; ++i) {
AffineExpr expr = set.getConstraint(i); AffineExpr expr = set.getConstraint(i);
@ -525,7 +531,8 @@ void FlatAffineConstraints::reset(unsigned numReservedInequalities,
unsigned numReservedEqualities, unsigned numReservedEqualities,
unsigned newNumReservedCols, unsigned newNumReservedCols,
unsigned newNumDims, unsigned newNumSymbols, unsigned newNumDims, unsigned newNumSymbols,
unsigned newNumLocals) { unsigned newNumLocals,
ArrayRef<MLValue *> idArgs) {
assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 && assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 &&
"minimum 1 column"); "minimum 1 column");
numReservedCols = newNumReservedCols; numReservedCols = newNumReservedCols;
@ -538,12 +545,20 @@ void FlatAffineConstraints::reset(unsigned numReservedInequalities,
equalities.reserve(newNumReservedCols * numReservedEqualities); equalities.reserve(newNumReservedCols * numReservedEqualities);
if (numReservedInequalities >= 1) if (numReservedInequalities >= 1)
inequalities.reserve(newNumReservedCols * numReservedInequalities); inequalities.reserve(newNumReservedCols * numReservedInequalities);
ids.clear();
if (idArgs.empty()) {
ids.resize(numIds, None);
} else {
ids.reserve(idArgs.size());
ids.insert(ids.end(), idArgs.begin(), idArgs.end());
}
} }
void FlatAffineConstraints::reset(unsigned newNumDims, unsigned newNumSymbols, void FlatAffineConstraints::reset(unsigned newNumDims, unsigned newNumSymbols,
unsigned newNumLocals) { unsigned newNumLocals,
ArrayRef<MLValue *> idArgs) {
reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims, reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims,
newNumSymbols, newNumLocals); newNumSymbols, newNumLocals, idArgs);
} }
void FlatAffineConstraints::append(const FlatAffineConstraints &other) { void FlatAffineConstraints::append(const FlatAffineConstraints &other) {
@ -567,8 +582,8 @@ void FlatAffineConstraints::addLocalId(unsigned pos) {
addId(IdKind::Local, pos); addId(IdKind::Local, pos);
} }
void FlatAffineConstraints::addDimId(unsigned pos) { void FlatAffineConstraints::addDimId(unsigned pos, MLValue *id) {
addId(IdKind::Dimension, pos); addId(IdKind::Dimension, pos, id);
} }
void FlatAffineConstraints::addSymbolId(unsigned pos) { void FlatAffineConstraints::addSymbolId(unsigned pos) {
@ -577,7 +592,7 @@ void FlatAffineConstraints::addSymbolId(unsigned pos) {
/// Adds a dimensional identifier. The added column is initialized to /// Adds a dimensional identifier. The added column is initialized to
/// zero. /// zero.
void FlatAffineConstraints::addId(IdKind kind, unsigned pos) { void FlatAffineConstraints::addId(IdKind kind, unsigned pos, MLValue *id) {
if (kind == IdKind::Dimension) { if (kind == IdKind::Dimension) {
assert(pos <= getNumDimIds()); assert(pos <= getNumDimIds());
} else if (kind == IdKind::Symbol) { } else if (kind == IdKind::Symbol) {
@ -595,16 +610,16 @@ void FlatAffineConstraints::addId(IdKind kind, unsigned pos) {
numReservedCols++; numReservedCols++;
} }
unsigned elimPos; unsigned absolutePos;
if (kind == IdKind::Dimension) { if (kind == IdKind::Dimension) {
elimPos = pos; absolutePos = pos;
numDims++; numDims++;
} else if (kind == IdKind::Symbol) { } else if (kind == IdKind::Symbol) {
elimPos = pos + getNumDimIds(); absolutePos = pos + getNumDimIds();
numSymbols++; numSymbols++;
} else { } else {
elimPos = pos + getNumDimIds() + getNumSymbolIds(); absolutePos = pos + getNumDimIds() + getNumSymbolIds();
} }
numIds++; numIds++;
@ -615,41 +630,53 @@ void FlatAffineConstraints::addId(IdKind kind, unsigned pos) {
int numCols = static_cast<int>(getNumCols()); int numCols = static_cast<int>(getNumCols());
for (int r = numInequalities - 1; r >= 0; r--) { for (int r = numInequalities - 1; r >= 0; r--) {
for (int c = numCols - 2; c >= 0; c--) { for (int c = numCols - 2; c >= 0; c--) {
if (c < elimPos) if (c < absolutePos)
atIneq(r, c) = inequalities[r * oldNumReservedCols + c]; atIneq(r, c) = inequalities[r * oldNumReservedCols + c];
else else
atIneq(r, c + 1) = inequalities[r * oldNumReservedCols + c]; atIneq(r, c + 1) = inequalities[r * oldNumReservedCols + c];
} }
atIneq(r, elimPos) = 0; atIneq(r, absolutePos) = 0;
} }
for (int r = numEqualities - 1; r >= 0; r--) { for (int r = numEqualities - 1; r >= 0; r--) {
for (int c = numCols - 2; c >= 0; c--) { for (int c = numCols - 2; c >= 0; c--) {
// All values in column elimPositions < elimPos have the same coordinates // All values in column absolutePositions < absolutePos have the same
// in the 2-d view of the coefficient buffer. // coordinates in the 2-d view of the coefficient buffer.
if (c < elimPos) if (c < absolutePos)
atEq(r, c) = equalities[r * oldNumReservedCols + c]; atEq(r, c) = equalities[r * oldNumReservedCols + c];
else else
// Those at elimPosition >= elimPos, get a shifted elimPosition. // Those at absolutePosition >= absolutePos, get a shifted
// absolutePosition.
atEq(r, c + 1) = equalities[r * oldNumReservedCols + c]; atEq(r, c + 1) = equalities[r * oldNumReservedCols + c];
} }
// Initialize added dimension to zero. // Initialize added dimension to zero.
atEq(r, elimPos) = 0; atEq(r, absolutePos) = 0;
} }
// If an 'id' is provided, insert it; otherwise use None.
if (id) {
ids.insert(ids.begin() + absolutePos, id);
} else {
ids.insert(ids.begin() + absolutePos, None);
}
assert(ids.size() == getNumIds());
} }
// This routine may add additional local variables if the flattened // This routine may add additional local variables if the flattened
// expression corresponding to the map has such variables due to the presence of // expression corresponding to the map has such variables due to the presence of
// mod's, ceildiv's, and floordiv's. // mod's, ceildiv's, and floordiv's.
void FlatAffineConstraints::composeMap(AffineValueMap *vMap, unsigned pos) { void FlatAffineConstraints::composeMap(AffineValueMap *vMap, unsigned pos) {
assert(vMap->getNumOperands() == getNumIds() && "inconsistent map");
assert(vMap->getNumDims() == getNumDimIds() && "inconsistent map");
assert(pos <= getNumIds() && "invalid position"); assert(pos <= getNumIds() && "invalid position");
assert(vMap->getNumSymbols() == getNumSymbolIds());
AffineMap map = vMap->getAffineMap(); AffineMap map = vMap->getAffineMap();
// We add one equality for each result connecting the result dim of the map to // We add one equality for each result connecting the result dim of the map to
// the other identifiers. // the other identifiers.
// For eg: if the expression is 16*i0 + i1, and this is the r^th
// iteration/result of the value map, we are adding the equality:
// d_r - 16*i0 - i1 = 0. Hence, when flattening say (i0 + 1, i0 + 8*i2), we
// add two equalities overall: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
for (unsigned r = 0, e = map.getNumResults(); r < e; r++) { for (unsigned r = 0, e = map.getNumResults(); r < e; r++) {
// Add dimension. // Add dimension.
addDimId(pos + r); addDimId(pos + r);
@ -660,44 +687,60 @@ void FlatAffineConstraints::composeMap(AffineValueMap *vMap, unsigned pos) {
map.getNumSymbols(), &eq, &cst); map.getNumSymbols(), &eq, &cst);
(void)ret; (void)ret;
assert(ret && "unimplemented for semi-affine maps"); assert(ret && "unimplemented for semi-affine maps");
for (unsigned j = 0, e = eq.size(); j < e; j++) {
eq[j] = -eq[j];
}
// Make the value map and the flat affine cst dimensions compatible. // Make the value map and the flat affine cst dimensions compatible.
// A lot of this code will be refactored/cleaned up. // A lot of this code will be refactored/cleaned up.
for (unsigned l = 0, e = cst.getNumLocalIds(); l < e; l++) { for (unsigned l = 0, e = cst.getNumLocalIds(); l < e; l++) {
addLocalId(getNumLocalIds()); addLocalId(0);
} }
// TODO(andydavis,bondhugula,ntv): we need common code to merge // TODO(andydavis,bondhugula,ntv): we need common code to merge
// dimensions/symbols. // dimensions/symbols.
assert(cst.getNumDimIds() <= getNumIds()); for (unsigned t = 0, e = r + 1; t < e; t++) {
for (unsigned t = 0, e = getNumDimIds() - cst.getNumDimIds(); t < e; t++) { // TODO: Consider using a batched version to add a range of IDs.
cst.addDimId(0); cst.addDimId(0);
eq.insert(eq.begin(), 0);
} }
// Set the ceofficient for this result to one.
eq[r] = 1; assert(cst.getNumDimIds() <= getNumDimIds());
// TODO(andydavis,bondhugula,ntv): we need common code to merge for (unsigned t = 0, e = getNumDimIds() - cst.getNumDimIds(); t < e; t++) {
// dimensions/symbols. cst.addDimId(cst.getNumDimIds() - 1);
assert(cst.getNumSymbolIds() <= getNumSymbolIds());
for (unsigned t = 0, e = getNumSymbolIds() - cst.getNumSymbolIds(); t < e;
t++) {
eq.insert(eq.begin() + cst.getNumSymbolIds(), 0);
cst.addSymbolId(cst.getNumSymbolIds());
} }
// TODO(andydavis,bondhugula,ntv): we need common code to merge // TODO(andydavis,bondhugula,ntv): we need common code to merge
// identifiers. All of this will be cleaned up. At this point, it's fine as // identifiers. All of this will be cleaned up. At this point, it's fine as
// long as it stays *inside* the FlatAffineConstraints API methods. // long as it stays *inside* the FlatAffineConstraints API methods.
assert(cst.getNumSymbolIds() <= getNumSymbolIds()); assert(cst.getNumLocalIds() <= getNumLocalIds());
for (unsigned t = 0, e = getNumLocalIds() - cst.getNumLocalIds(); t < e; for (unsigned t = 0, e = getNumLocalIds() - cst.getNumLocalIds(); t < e;
t++) { t++) {
eq.insert(eq.begin() + cst.getNumDimIds() + cst.getNumSymbolIds(), 0); cst.addLocalId(cst.getNumLocalIds());
cst.addLocalId(0);
} }
/// Finally, append cst to this constraint set. /// Finally, append cst to this constraint set.
append(cst); append(cst);
// eqToAdd is the equality corresponding to the flattened affine expression.
SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
// Set the coefficient for this result to one.
eqToAdd[r] = 1;
// Dims and symbols.
for (unsigned i = 0, e = vMap->getNumOperands(); i < e; i++) {
unsigned loc;
bool ret = findId(*cast<MLValue>(vMap->getOperand(i)), &loc);
assert(ret && "id expected, but not found");
(void)ret;
// We need to negate 'eq' since the newly added dimension is going to be
// set to this one.
eqToAdd[loc] = -eq[i];
}
// Local vars common to eq and cst are at the beginning.
int j = getNumDimIds() + getNumSymbolIds();
int end = eq.size() - 1;
for (int i = vMap->getNumOperands(); i < end; i++, j++) {
eqToAdd[j] = -eq[i];
}
// Constant term.
eqToAdd[getNumCols() - 1] = -eq[eq.size() - 1];
// Add the equality connecting the result of the map to this constraint set. // Add the equality connecting the result of the map to this constraint set.
addEquality(eq); addEquality(eqToAdd);
} }
} }
@ -858,6 +901,7 @@ void FlatAffineConstraints::removeColumnRange(unsigned colStart,
numDims -= numDimsEliminated; numDims -= numDimsEliminated;
numSymbols -= numSymbolsEliminated; numSymbols -= numSymbolsEliminated;
numIds = numIds - numColsEliminated; numIds = numIds - numColsEliminated;
ids.erase(ids.begin() + colStart, ids.begin() + colLimit);
// No resize necessary. numReservedCols remains the same. // No resize necessary. numReservedCols remains the same.
} }
@ -1071,6 +1115,90 @@ void FlatAffineConstraints::addUpperBound(ArrayRef<int64_t> expr,
} }
} }
bool FlatAffineConstraints::findId(const MLValue &operand, unsigned *pos) {
unsigned i = 0;
for (const auto &mayBeId : ids) {
if (mayBeId.hasValue() && mayBeId.getValue() == &operand) {
*pos = i;
return true;
}
i++;
}
return false;
}
// TODO(andydavis, bondhugula) AFFINE REFACTOR: merge with loop bounds
// code in dependence analysis.
bool FlatAffineConstraints::addBoundsFromForStmt(unsigned pos,
ForStmt *forStmt) {
// Adds a lower or upper bound when the bounds aren't constant.
auto addLowerOrUpperBound = [&](bool lower) -> bool {
const auto &operands = lower ? forStmt->getLowerBoundOperands()
: forStmt->getUpperBoundOperands();
SmallVector<unsigned, 8> positions;
for (const auto &operand : operands) {
unsigned loc;
// TODO(andydavis, bondhugula) AFFINE REFACTOR: merge with loop bounds
// code in dependence analysis.
if (!findId(*operand, &loc)) {
addDimId(getNumDimIds(), operand);
loc = getNumDimIds() - 1;
}
positions.push_back(loc);
}
auto boundMap =
lower ? forStmt->getLowerBoundMap() : forStmt->getUpperBoundMap();
for (auto result : boundMap.getResults()) {
SmallVector<int64_t, 4> flattenedExpr;
SmallVector<int64_t, 4> ineq(getNumCols(), 0);
// TODO(andydavis, bondhugula) AFFINE REFACTOR: merge with loop bounds in
// dependence analysis.
FlatAffineConstraints cst;
if (!getFlattenedAffineExpr(result, boundMap.getNumDims(),
boundMap.getNumSymbols(), &flattenedExpr,
&cst)) {
LLVM_DEBUG(llvm::dbgs()
<< "semi-affine expressions not yet supported\n");
return false;
}
if (cst.getNumLocalIds() > 0) {
LLVM_DEBUG(
llvm::dbgs()
<< "loop bounds with mod/floordiv expr's not yet supported\n");
return false;
}
ineq[pos] = lower ? 1 : -1;
for (unsigned j = 0, e = boundMap.getNumInputs(); j < e; j++) {
ineq[positions[j]] = lower ? -flattenedExpr[j] : flattenedExpr[j];
}
// Constant term.
ineq[getNumCols() - 1] = lower ? -flattenedExpr[flattenedExpr.size() - 1]
: flattenedExpr[flattenedExpr.size() - 1];
addInequality(ineq);
}
return true;
};
if (forStmt->hasConstantLowerBound()) {
addConstantLowerBound(pos, forStmt->getConstantLowerBound());
} else {
// Non-constant lower bound case.
if (!addLowerOrUpperBound(/*lower=*/true))
return false;
}
if (forStmt->hasConstantUpperBound()) {
addConstantUpperBound(pos, forStmt->getConstantUpperBound() - 1);
return true;
}
// Non-constant upper bound case.
return addLowerOrUpperBound(/*lower=*/false);
}
/// Sets the specified identifer to a constant value. /// Sets the specified identifer to a constant value.
void FlatAffineConstraints::setIdToConstant(unsigned pos, int64_t val) { void FlatAffineConstraints::setIdToConstant(unsigned pos, int64_t val) {
unsigned offset = equalities.size(); unsigned offset = equalities.size();
@ -1119,7 +1247,8 @@ bool FlatAffineConstraints::getDimensionBounds(unsigned pos, unsigned num,
return true; return true;
} }
Optional<int64_t> FlatAffineConstraints::getConstantLowerBound(unsigned pos) { Optional<int64_t>
FlatAffineConstraints::getConstantLowerBound(unsigned pos) const {
assert(pos < getNumCols() - 1); assert(pos < getNumCols() - 1);
Optional<int64_t> lb = None; Optional<int64_t> lb = None;
for (unsigned r = 0; r < getNumInequalities(); r++) { for (unsigned r = 0; r < getNumInequalities(); r++) {
@ -1143,7 +1272,71 @@ Optional<int64_t> FlatAffineConstraints::getConstantLowerBound(unsigned pos) {
return lb; return lb;
} }
Optional<int64_t> FlatAffineConstraints::getConstantUpperBound(unsigned pos) { /// Returns the extent of the specified identifier (upper bound - lower bound)
/// if it found to be a constant; returns None if it's not a constant.
/// 'lbPosition' is set to the row position of the corresponding lower bound.
Optional<int64_t>
FlatAffineConstraints::getConstantBoundDifference(unsigned pos,
unsigned *lbPosition) const {
// Check if the identifier appears at all in any of the inequalities.
unsigned r, e;
for (r = 0, e = getNumInequalities(); r < e; r++) {
if (atIneq(r, pos) != 0)
break;
}
if (r == e) {
// If it doesn't appear, just remove the column and return.
// TODO(andydavis,bondhugula): refactor removeColumns to use it from here.
return None;
}
// Positions of constraints that are lower/upper bounds on the variable.
SmallVector<unsigned, 4> lbIndices, ubIndices;
// Gather all lower bounds and upper bounds of the variable. Since the
// canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
// bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
if (atIneq(r, pos) >= 1)
// Lower bound.
lbIndices.push_back(r);
else if (atIneq(r, pos) <= -1)
// Upper bound.
ubIndices.push_back(r);
}
// TODO(bondhugula): eliminate all variables that aren't part of any of the
// lower/upper bounds - to make this more powerful.
Optional<int64_t> minDiff = None;
for (auto ubPos : ubIndices) {
for (auto lbPos : lbIndices) {
// Look for a lower bound and an upper bound that only differ by a
// constant, i.e., pairs of the form 0 <= c_pos - f(c_i's) <= diffConst.
// For example, if ii is the pos^th variable, we are looking for
// constraints like ii >= i, ii <= ii + 50, 50 being the difference. The
// minimum among all such constant differences is kept since that's the
// constant bounding the extent of the pos^th variable.
unsigned j;
for (j = 0; j < getNumCols() - 1; j++)
if (atIneq(ubPos, j) != -atIneq(lbPos, j)) {
break;
}
if (j < getNumCols() - 1)
continue;
int64_t mayDiff =
atIneq(ubPos, getNumCols() - 1) + atIneq(lbPos, getNumCols() - 1) + 1;
if (minDiff == None || mayDiff < minDiff) {
minDiff = mayDiff;
*lbPosition = lbPos;
}
}
}
return minDiff;
}
Optional<int64_t>
FlatAffineConstraints::getConstantUpperBound(unsigned pos) const {
assert(pos < getNumCols() - 1); assert(pos < getNumCols() - 1);
Optional<int64_t> ub = None; Optional<int64_t> ub = None;
for (unsigned r = 0; r < getNumInequalities(); r++) { for (unsigned r = 0; r < getNumInequalities(); r++) {
@ -1196,8 +1389,17 @@ bool FlatAffineConstraints::isHyperRectangular(unsigned pos,
void FlatAffineConstraints::print(raw_ostream &os) const { void FlatAffineConstraints::print(raw_ostream &os) const {
assert(inequalities.size() == getNumInequalities() * numReservedCols); assert(inequalities.size() == getNumInequalities() * numReservedCols);
assert(equalities.size() == getNumEqualities() * numReservedCols); assert(equalities.size() == getNumEqualities() * numReservedCols);
assert(ids.size() == getNumIds());
os << "\nConstraints (" << getNumDimIds() << " dims, " << getNumSymbolIds() os << "\nConstraints (" << getNumDimIds() << " dims, " << getNumSymbolIds()
<< " symbols, " << getNumLocalIds() << " locals): \n"; << " symbols, " << getNumLocalIds() << " locals): \n";
os << "(";
for (unsigned i = 0, e = getNumIds(); i < e; i++) {
if (ids[i] == None)
os << "None ";
else
os << "MLValue ";
}
os << ")\n";
for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
for (unsigned j = 0; j < getNumCols(); ++j) { for (unsigned j = 0; j < getNumCols(); ++j) {
os << atEq(i, j) << " "; os << atEq(i, j) << " ";
@ -1223,6 +1425,7 @@ void FlatAffineConstraints::clearAndCopyFrom(
const FlatAffineConstraints &other) { const FlatAffineConstraints &other) {
FlatAffineConstraints copy(other); FlatAffineConstraints copy(other);
std::swap(*this, copy); std::swap(*this, copy);
assert(copy.getNumIds() == copy.getIds().size());
} }
void FlatAffineConstraints::removeId(unsigned pos) { void FlatAffineConstraints::removeId(unsigned pos) {
@ -1245,6 +1448,7 @@ void FlatAffineConstraints::removeId(unsigned pos) {
atEq(r, c) = atEq(r, c + 1); atEq(r, c) = atEq(r, c + 1);
} }
} }
ids.erase(ids.begin() + pos);
} }
static std::pair<unsigned, unsigned> static std::pair<unsigned, unsigned>
@ -1375,11 +1579,18 @@ void FlatAffineConstraints::FourierMotzkinEliminate(
unsigned newNumDims = dimsSymbols.first; unsigned newNumDims = dimsSymbols.first;
unsigned newNumSymbols = dimsSymbols.second; unsigned newNumSymbols = dimsSymbols.second;
SmallVector<Optional<MLValue *>, 8> newIds;
newIds.reserve(numIds - 1);
newIds.insert(newIds.end(), ids.begin(), ids.begin() + pos);
newIds.insert(newIds.end(), ids.begin() + pos + 1, ids.end());
/// Create the new system which has one identifier less. /// Create the new system which has one identifier less.
FlatAffineConstraints newFac( FlatAffineConstraints newFac(
lbIndices.size() * ubIndices.size() + nbIndices.size(), lbIndices.size() * ubIndices.size() + nbIndices.size(),
getNumEqualities(), getNumCols() - 1, newNumDims, newNumSymbols, getNumEqualities(), getNumCols() - 1, newNumDims, newNumSymbols,
/*numLocals=*/getNumIds() - 1 - newNumDims - newNumSymbols); /*numLocals=*/getNumIds() - 1 - newNumDims - newNumSymbols, newIds);
assert(newFac.getIds().size() == newFac.getNumIds());
// This will be used to check if the elimination was integer exact. // This will be used to check if the elimination was integer exact.
unsigned lcmProducts = 1; unsigned lcmProducts = 1;
@ -1462,9 +1673,19 @@ void FlatAffineConstraints::FourierMotzkinEliminate(
void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) { void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) {
// 'pos' can be at most getNumCols() - 2. // 'pos' can be at most getNumCols() - 2.
if (num == 0)
return;
assert(pos <= getNumCols() - 2 && "invalid position"); assert(pos <= getNumCols() - 2 && "invalid position");
assert(pos + num < getNumCols() && "invalid range"); assert(pos + num < getNumCols() && "invalid range");
for (unsigned i = 0; i < num; i++) { for (unsigned i = 0; i < num; i++) {
FourierMotzkinEliminate(pos); FourierMotzkinEliminate(pos);
} }
} }
void FlatAffineConstraints::projectOut(MLValue *id) {
unsigned pos;
bool ret = findId(*id, &pos);
assert(ret);
(void)ret;
FourierMotzkinEliminate(pos);
}

View File

@ -63,15 +63,15 @@ void MemRefBoundCheck::visitOperationStmt(OperationStmt *opStmt) {
// TODO(bondhugula): extend this to store's and other memref dereferencing // TODO(bondhugula): extend this to store's and other memref dereferencing
// op's. // op's.
if (auto loadOp = opStmt->dyn_cast<LoadOp>()) { if (auto loadOp = opStmt->dyn_cast<LoadOp>()) {
FlatAffineConstraints memoryRegion; MemRefRegion region;
if (!getMemoryRegion(opStmt, &memoryRegion)) if (!getMemRefRegion(opStmt, /*loopDepth=*/0, &region))
return; return;
LLVM_DEBUG(llvm::dbgs() << "Memory region"); LLVM_DEBUG(llvm::dbgs() << "Memory region");
LLVM_DEBUG(memoryRegion.dump()); LLVM_DEBUG(region.getConstraints()->dump());
unsigned rank = loadOp->getMemRefType().getRank(); unsigned rank = loadOp->getMemRefType().getRank();
// For each dimension, check for out of bounds. // For each dimension, check for out of bounds.
for (unsigned r = 0; r < rank; r++) { for (unsigned r = 0; r < rank; r++) {
FlatAffineConstraints ucst(memoryRegion); FlatAffineConstraints ucst(*region.getConstraints());
// Intersect memory region with constraint capturing out of bounds, // Intersect memory region with constraint capturing out of bounds,
// and check if the constraint system is feasible. If it is, there is at // and check if the constraint system is feasible. If it is, there is at
// least one point out of bounds. // least one point out of bounds.
@ -91,7 +91,7 @@ void MemRefBoundCheck::visitOperationStmt(OperationStmt *opStmt) {
Twine(r + 1)); Twine(r + 1));
} }
// Check for less than negative index. // Check for less than negative index.
FlatAffineConstraints lcst(memoryRegion); FlatAffineConstraints lcst(*region.getConstraints());
std::fill(ineq.begin(), ineq.end(), 0); std::fill(ineq.begin(), ineq.end(), 0);
// d_i <= -1; // d_i <= -1;
lcst.addConstantUpperBound(r, -1); lcst.addConstantUpperBound(r, -1);
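
For reference, the out-of-bounds test touched by the hunk above works per memref dimension r as summarized below (editorial summary; the exact form of the upper-bound violation constraint is elided in the hunk, so it is stated here as an assumption).

// For each dimension r of the accessed memref:
//   - copy the computed region's constraints, add a "violation" constraint,
//     and check feasibility; any feasible point is an out-of-bounds access.
//   - upper check (assumed form): intersect with d_r >= dimSize(r).
//   - lower check (shown above):  intersect with d_r <= -1, i.e.
//     addConstantUpperBound(r, -1).
// Since getMemRefRegion now handles store op's, -memref-bound-check covers
// both loads and stores.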

View File

@ -27,6 +27,9 @@
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/StandardOps/StandardOps.h" #include "mlir/StandardOps/StandardOps.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "analysis-utils"
using namespace mlir; using namespace mlir;
@ -65,62 +68,141 @@ bool mlir::dominates(const Statement &a, const Statement &b) {
return &a == &b || properlyDominates(a, b); return &a == &b || properlyDominates(a, b);
} }
/// Returns the memory region accessed by this memref. Optional<int64_t> MemRefRegion::getConstantSize() const {
// TODO(bondhugula): extend this to store's and other memref dereferencing ops. auto memRefType = memref->getType().cast<MemRefType>();
bool mlir::getMemoryRegion(OperationStmt *opStmt, unsigned rank = memRefType.getRank();
FlatAffineConstraints *region) {
OpPointer<LoadOp> loadOp; // Compute the extents of the buffer.
if (!(loadOp = opStmt->dyn_cast<LoadOp>())) int64_t numElements = 1;
return false; for (unsigned d = 0; d < rank; d++) {
unsigned lbPos;
Optional<int64_t> diff = cst.getConstantBoundDifference(d, &lbPos);
if (!diff.hasValue())
return None;
int64_t diffConstant = diff.getValue();
if (diffConstant <= 0)
return 0;
numElements *= diffConstant;
}
return numElements;
}
bool MemRefRegion::getConstantShape(SmallVectorImpl<int> *shape) const {
auto memRefType = memref->getType().cast<MemRefType>();
unsigned rank = memRefType.getRank();
shape->reserve(rank);
// Compute the extents of this memref region.
for (unsigned d = 0; d < rank; d++) {
unsigned lbPos;
Optional<int64_t> diff = cst.getConstantBoundDifference(d, &lbPos);
if (!diff.hasValue())
return false;
int diffConstant = std::max(0L, diff.getValue());
shape->push_back(diffConstant);
}
return true;
}
/// Computes the memory region accessed by this memref with the region
/// represented as constraints symbolic/parameteric in 'loopDepth' loops
/// surrounding opStmt. Returns false if this fails due to yet unimplemented
/// cases.
// For example, the memref region for this load operation at loopDepth = 1 will
// be as below:
//
// for %i = 0 to 32 {
// for %ii = %i to (d0) -> (d0 + 8) (%i) {
// load %A[%ii]
// }
// }
//
// region: {memref = %A, write = false, {%i <= m0 <= %i + 7} }
// The last field is a 2-d FlatAffineConstraints symbolic in %i.
//
// TODO(bondhugula): extend this to any other memref dereferencing ops
// (dma_start, dma_wait).
bool mlir::getMemRefRegion(OperationStmt *opStmt, unsigned loopDepth,
MemRefRegion *region) {
OpPointer<LoadOp> loadOp;
OpPointer<StoreOp> storeOp;
unsigned rank;
SmallVector<MLValue *, 4> indices;
if ((loadOp = opStmt->dyn_cast<LoadOp>())) {
rank = loadOp->getMemRefType().getRank();
for (auto *index : loadOp->getIndices()) {
indices.push_back(cast<MLValue>(index));
}
region->memref = cast<MLValue>(loadOp->getMemRef());
region->setWrite(false);
} else if ((storeOp = opStmt->dyn_cast<StoreOp>())) {
rank = storeOp->getMemRefType().getRank();
for (auto *index : storeOp->getIndices()) {
indices.push_back(cast<MLValue>(index));
}
region->memref = cast<MLValue>(storeOp->getMemRef());
region->setWrite(true);
} else {
return false;
}
// Build the constraints for this region.
FlatAffineConstraints *regionCst = region->getConstraints();
unsigned rank = loadOp->getMemRefType().getRank();
MLFuncBuilder b(opStmt); MLFuncBuilder b(opStmt);
auto idMap = b.getMultiDimIdentityMap(rank); auto idMap = b.getMultiDimIdentityMap(rank);
SmallVector<MLValue *, 4> indices; // Initialize 'accessValueMap' and compose with reachable AffineApplyOps.
for (auto *index : loadOp->getIndices()) { AffineValueMap accessValueMap(idMap, indices);
indices.push_back(cast<MLValue>(index)); forwardSubstituteReachableOps(&accessValueMap);
} AffineMap accessMap = accessValueMap.getAffineMap();
// Initialize 'accessMap' and compose with reachable AffineApplyOps. regionCst->reset(accessMap.getNumDims(), accessMap.getNumSymbols(), 0,
AffineValueMap accessMap(idMap, indices); accessValueMap.getOperands());
forwardSubstituteReachableOps(&accessMap);
AffineMap srcMap = accessMap.getAffineMap();
region->reset(srcMap.getNumDims(), srcMap.getNumSymbols());
// Add equality constraints. // Add equality constraints.
AffineMap map = accessMap.getAffineMap(); unsigned numDims = accessMap.getNumDims();
unsigned numDims = map.getNumDims(); unsigned numSymbols = accessMap.getNumSymbols();
unsigned numSymbols = map.getNumSymbols(); // Add inequalties for loop lower/upper bounds.
// Add inEqualties for loop lower/upper bounds.
for (unsigned i = 0; i < numDims + numSymbols; ++i) { for (unsigned i = 0; i < numDims + numSymbols; ++i) {
if (auto *loop = dyn_cast<ForStmt>(accessMap.getOperand(i))) { if (auto *loop = dyn_cast<ForStmt>(accessValueMap.getOperand(i))) {
if (!loop->hasConstantBounds()) // Note that regionCst can now have more dimensions than accessMap if the
// bounds expressions involve outer loops or other symbols.
if (!regionCst->addBoundsFromForStmt(i, loop))
return false; return false;
// Add lower bound and upper bounds.
region->addConstantLowerBound(i, loop->getConstantLowerBound());
region->addConstantUpperBound(i, loop->getConstantUpperBound() - 1);
} else { } else {
// Has to be a valid symbol. // Has to be a valid symbol.
auto *symbol = cast<MLValue>(accessMap.getOperand(i)); auto *symbol = cast<MLValue>(accessValueMap.getOperand(i));
assert(symbol->isValidSymbol()); assert(symbol->isValidSymbol());
// Check if the symbols is a constant. // Check if the symbols is a constant.
if (auto *opStmt = symbol->getDefiningStmt()) { if (auto *opStmt = symbol->getDefiningStmt()) {
if (auto constOp = opStmt->dyn_cast<ConstantIndexOp>()) { if (auto constOp = opStmt->dyn_cast<ConstantIndexOp>()) {
region->setIdToConstant(i, constOp->getValue()); regionCst->setIdToConstant(i, constOp->getValue());
} }
} }
} }
} }
// Add access function equalities to connect loop IVs to data dimensions. // Add access function equalities to connect loop IVs to data dimensions.
region->composeMap(&accessMap); regionCst->composeMap(&accessValueMap);
// Eliminate the loop IVs and any local variables to yield the memory region // Eliminate the loop IVs and any local variables to yield the memory
// involving just the memref dimensions. // region involving just the memref dimensions and outer loop IVs up to
region->projectOut(srcMap.getNumResults(), // loopDepth.
accessMap.getNumOperands() + region->getNumLocalIds()); for (auto *operand : accessValueMap.getOperands()) {
assert(region->getNumDimIds() == rank); regionCst->projectOut(operand);
}
regionCst->projectOut(regionCst->getNumDimIds() +
regionCst->getNumSymbolIds(),
regionCst->getNumLocalIds());
// Tighten the set.
regionCst->GCDTightenInequalities();
assert(regionCst->getNumDimIds() >= rank);
return true; return true;
} }

View File

@ -717,6 +717,7 @@ void DmaStartOp::print(OpAsmPrinter *p) const {
*p << " : " << getSrcMemRef()->getType(); *p << " : " << getSrcMemRef()->getType();
*p << ", " << getDstMemRef()->getType(); *p << ", " << getDstMemRef()->getType();
*p << ", " << getTagMemRef()->getType(); *p << ", " << getTagMemRef()->getType();
p->printOptionalAttrDict(getAttrs());
} }
// Parse DmaStartOp. // Parse DmaStartOp.
@ -811,6 +812,7 @@ void DmaWaitOp::print(OpAsmPrinter *p) const {
*p << "], "; *p << "], ";
p->printOperand(getNumElements()); p->printOperand(getNumElements());
*p << " : " << getTagMemRef()->getType(); *p << " : " << getTagMemRef()->getType();
p->printOptionalAttrDict(getAttrs());
} }
// Parse DmaWaitOp. // Parse DmaWaitOp.

View File

@ -30,193 +30,306 @@
#include "mlir/StandardOps/StandardOps.h" #include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h" #include "mlir/Transforms/Utils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
#include <algorithm> #include <algorithm>
#define DEBUG_TYPE "dma-generate" #define DEBUG_TYPE "dma-generate"
using namespace mlir; using namespace mlir;
static llvm::cl::opt<unsigned> clFastMemorySpace(
"dma-fast-memory-space", llvm::cl::Hidden,
llvm::cl::desc("Set fast memory space id for DMA generation"));
namespace { namespace {
// A region of memory in a lower memory space. /// Generates DMAs for memref's living in 'slowMemorySpace' into newly created
struct Region { /// buffers in 'fastMemorySpace', and replaces memory operations to the former
// Memref corresponding to the region.
MLValue *memref;
// Read or write.
bool isWrite;
// Region of memory accessed.
// TODO(bondhugula): Replace this to exploit HyperRectangularSet.
std::unique_ptr<FlatAffineConstraints> cst;
};
/// Generates DMAs for memref's living in 'lowMemorySpace' into newly created
/// buffers in 'highMemorySpace', and replaces memory operations to the former
/// by the latter. Only load op's handled for now. /// by the latter. Only load op's handled for now.
/// TODO(bondhugula): extend this to store op's. /// TODO(bondhugula): extend this to store op's.
struct DmaGeneration : public FunctionPass, StmtWalker<DmaGeneration> { struct DmaGeneration : public FunctionPass, StmtWalker<DmaGeneration> {
explicit DmaGeneration(unsigned lowMemorySpace = 0, explicit DmaGeneration(unsigned slowMemorySpace = 0,
unsigned highMemorySpace = 1, unsigned fastMemorySpaceArg = 1,
int minDmaTransferSize = 1024) int minDmaTransferSize = 1024)
: FunctionPass(&DmaGeneration::passID), lowMemorySpace(lowMemorySpace), : FunctionPass(&DmaGeneration::passID), slowMemorySpace(slowMemorySpace),
highMemorySpace(highMemorySpace), minDmaTransferSize(minDmaTransferSize) {
minDmaTransferSize(minDmaTransferSize) {} if (clFastMemorySpace.getNumOccurrences() > 0) {
fastMemorySpace = clFastMemorySpace;
} else {
fastMemorySpace = fastMemorySpaceArg;
}
}
PassResult runOnMLFunction(MLFunction *f) override;
// Not applicable to CFG functions. // Not applicable to CFG functions.
PassResult runOnCFGFunction(CFGFunction *f) override { return success(); } PassResult runOnCFGFunction(CFGFunction *f) override { return success(); }
bool runOnForStmt(ForStmt *forStmt); PassResult runOnMLFunction(MLFunction *f) override;
void runOnForStmt(ForStmt *forStmt);
void visitOperationStmt(OperationStmt *opStmt); void visitOperationStmt(OperationStmt *opStmt);
void generateDma(const Region &region, Location loc, MLFuncBuilder *b); bool generateDma(const MemRefRegion &region, ForStmt *forStmt);
// List of memory regions to promote. // List of memory regions to DMA for.
std::vector<Region> regions; std::vector<std::unique_ptr<MemRefRegion>> regions;
// Map from original memref's to the DMA buffers that their accesses are
// replaced with.
DenseMap<SSAValue *, SSAValue *> fastBufferMap;
// Slow memory space associated with DMAs.
const unsigned slowMemorySpace;
// Fast memory space associated with DMAs.
unsigned fastMemorySpace;
// Minimum DMA transfer size supported by the target in bytes.
const int minDmaTransferSize;
// The loop level at which DMAs should be generated. '0' is an outermost loop.
unsigned dmaDepth;
static char passID; static char passID;
const unsigned lowMemorySpace;
const unsigned highMemorySpace;
const int minDmaTransferSize;
}; };
} // end anonymous namespace } // end anonymous namespace
char DmaGeneration::passID = 0; char DmaGeneration::passID = 0;
/// Generates DMAs for memref's living in 'lowMemorySpace' into newly created /// Generates DMAs for memref's living in 'slowMemorySpace' into newly created
/// buffers in 'highMemorySpace', and replaces memory operations to the former /// buffers in 'fastMemorySpace', and replaces memory operations to the former
/// by the latter. Only load op's handled for now. /// by the latter. Only load op's handled for now.
/// TODO(bondhugula): extend this to store op's. /// TODO(bondhugula): extend this to store op's.
FunctionPass *mlir::createDmaGenerationPass(unsigned lowMemorySpace, FunctionPass *mlir::createDmaGenerationPass(unsigned slowMemorySpace,
unsigned highMemorySpace, unsigned fastMemorySpace,
int minDmaTransferSize) { int minDmaTransferSize) {
return new DmaGeneration(lowMemorySpace, highMemorySpace, minDmaTransferSize); return new DmaGeneration(slowMemorySpace, fastMemorySpace,
minDmaTransferSize);
} }
// Gather regions to promote to buffers in higher memory space. // Gather regions to promote to buffers in faster memory space.
// TODO(bondhugula): handle store op's; only load's handled for now. // TODO(bondhugula): handle store op's; only load's handled for now.
void DmaGeneration::visitOperationStmt(OperationStmt *opStmt) { void DmaGeneration::visitOperationStmt(OperationStmt *opStmt) {
if (auto loadOp = opStmt->dyn_cast<LoadOp>()) { if (auto loadOp = opStmt->dyn_cast<LoadOp>()) {
if (loadOp->getMemRefType().getMemorySpace() != lowMemorySpace) if (loadOp->getMemRefType().getMemorySpace() != slowMemorySpace)
return; return;
} else if (auto storeOp = opStmt->dyn_cast<StoreOp>()) {
// TODO(bondhugula): eventually, we need to be performing a union across all if (storeOp->getMemRefType().getMemorySpace() != slowMemorySpace)
// regions for a given memref instead of creating one region per memory op.
// This way we would be allocating O(num of memref's) sets instead of
// O(num of load/store op's).
auto memoryRegion = std::make_unique<FlatAffineConstraints>();
if (!getMemoryRegion(opStmt, memoryRegion.get())) {
LLVM_DEBUG(llvm::dbgs() << "Error obtaining memory region");
return; return;
} } else {
LLVM_DEBUG(llvm::dbgs() << "Memory region"); // Neither load nor a store op.
LLVM_DEBUG(memoryRegion->dump()); return;
regions.push_back(
{cast<MLValue>(loadOp->getMemRef()), false, std::move(memoryRegion)});
} }
// TODO(bondhugula): eventually, we need to be performing a union across all
// regions for a given memref instead of creating one region per memory op.
// This way we would be allocating O(num of memref's) sets instead of
// O(num of load/store op's).
auto region = std::make_unique<MemRefRegion>();
if (!getMemRefRegion(opStmt, dmaDepth, region.get())) {
LLVM_DEBUG(llvm::dbgs() << "Error obtaining memory region\n");
return;
}
LLVM_DEBUG(llvm::dbgs() << "Memory region:\n");
LLVM_DEBUG(region->getConstraints()->dump());
regions.push_back(std::move(region));
} }
-// Create a buffer in the higher (faster) memory space for the specified region;
-// generate a DMA from the lower memory space to this one, and replace all loads
-// to load from the buffer.
-// TODO: handle write regions by generating outgoing DMAs; only read regions are
-// handled for now.
-void DmaGeneration::generateDma(const Region &region, Location loc,
-                                MLFuncBuilder *b) {
-  // Only memref read regions handled for now.
-  if (region.isWrite)
-    return;
-
+// Creates a buffer in the faster memory space for the specified region;
+// generates a DMA from the lower memory space to this one, and replaces all
+// loads to load from the buffer. Returns true if DMAs are generated.
+bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt) {
+  // DMAs for read regions are going to be inserted just before the for loop.
+  MLFuncBuilder prologue(forStmt);
+  // DMAs for write regions are going to be inserted just after the for loop.
+  MLFuncBuilder epilogue(forStmt->getBlock(),
+                         std::next(StmtBlock::iterator(forStmt)));
+  MLFuncBuilder *b = region.isWrite() ? &epilogue : &prologue;
+
+  // Builder to create constants at the top level.
+  MLFuncBuilder top(forStmt->findFunction());
+
+  FlatAffineConstraints *cst =
+      const_cast<FlatAffineConstraints *>(region.getConstraints());
+
+  auto loc = forStmt->getLoc();
   auto *memref = region.memref;
   auto memRefType = memref->getType().cast<MemRefType>();
 
+  // Indices to use for DmaStart op.
   SmallVector<SSAValue *, 4> srcIndices, destIndices;
-  SSAValue *zeroIndex = b->create<ConstantIndexOp>(loc, 0);
+
+  SSAValue *zeroIndex = top.create<ConstantIndexOp>(loc, 0);
   unsigned rank = memRefType.getRank();
   SmallVector<int, 4> shape;
+  shape.reserve(rank);
+
+  // Compute the extents of the buffer.
+  Optional<int64_t> numElements = region.getConstantSize();
+  if (!numElements.hasValue()) {
+    LLVM_DEBUG(llvm::dbgs() << "Non-constant region size\n");
+    return false;
+  }
+  if (numElements.getValue() == 0) {
+    LLVM_DEBUG(llvm::dbgs() << "Nothing to DMA\n");
+    return false;
+  }
+  region.getConstantShape(&shape);
 
   // Index start offsets for faster memory buffer relative to the original.
-  SmallVector<int, 4> offsets;
+  SmallVector<AffineExpr, 4> offsets;
   offsets.reserve(rank);
-  unsigned numElements = 1;
   for (unsigned d = 0; d < rank; d++) {
-    auto lb = region.cst->getConstantLowerBound(d);
-    auto ub = region.cst->getConstantUpperBound(d);
-    if (!lb.hasValue() || !ub.hasValue()) {
-      LLVM_DEBUG(llvm::dbgs() << "Non-constant loop bounds");
-      return;
+    unsigned lbPos;
+    cst->getConstantBoundDifference(d, &lbPos);
+
+    // Construct the index expressions for the fast memory buffer. The index
+    // expression for a particular dimension of the fast buffer is obtained by
+    // subtracting out the lower bound on the original memref's data region
+    // along the corresponding dimension.
+    AffineExpr offset = top.getAffineConstantExpr(0);
+    for (unsigned j = rank; j < cst->getNumCols() - 1; j++) {
+      offset = offset - cst->atIneq(lbPos, j) * top.getAffineDimExpr(j - rank);
     }
-    offsets.push_back(lb.getValue());
-    int dimSize = ub.getValue() - lb.getValue() + 1;
-    if (dimSize <= 0)
-      return;
-    shape.push_back(dimSize);
-    numElements *= dimSize;
-    srcIndices.push_back(b->create<ConstantIndexOp>(loc, lb.getValue()));
+    offset = offset - cst->atIneq(lbPos, cst->getNumCols() - 1);
+    offsets.push_back(offset);
+
+    auto ids = cst->getIds();
+    SmallVector<SSAValue *, 8> operands;
+    for (unsigned i = rank, e = ids.size(); i < e; i++) {
+      auto id = cst->getIds()[i];
+      assert(id.hasValue());
+      operands.push_back(id.getValue());
+    }
+    // Set DMA start location for this dimension in the lower memory space
+    // memref.
+    if (auto caf = offsets[d].dyn_cast<AffineConstantExpr>()) {
+      srcIndices.push_back(cast<MLValue>(
+          top.create<ConstantIndexOp>(loc, caf.getValue())->getResult()));
+    } else {
+      auto map =
+          top.getAffineMap(cst->getNumDimIds() + cst->getNumSymbolIds() - rank,
+                           0, offsets[d], {});
+      srcIndices.push_back(cast<MLValue>(
+          b->create<AffineApplyOp>(loc, map, operands)->getResult(0)));
+    }
+    // The fast buffer is DMAed into at location zero; addressing is relative.
     destIndices.push_back(zeroIndex);
   }
 
-  // Create the faster memref buffer.
-  auto fastMemRefType =
-      b->getMemRefType(shape, memRefType.getElementType(), {}, highMemorySpace);
-  auto fastMemRef = b->create<AllocOp>(loc, fastMemRefType)->getResult();
+  SSAValue *fastMemRef;
+
+  // Check if a buffer was already created.
+  // TODO(bondhugula): union across all memory op's per buffer. For now assuming
+  // that multiple memory op's on the same memref have the *same* memory
+  // footprint.
+  if (fastBufferMap.find(memref) == fastBufferMap.end()) {
+    auto fastMemRefType = top.getMemRefType(shape, memRefType.getElementType(),
+                                            {}, fastMemorySpace);
+    LLVM_DEBUG(llvm::dbgs() << "Creating a new buffer of type: ");
+    LLVM_DEBUG(fastMemRefType.dump(); llvm::dbgs() << "\n");
+    // Create the fast memory space buffer just before the 'for' statement.
+    fastMemRef = prologue.create<AllocOp>(loc, fastMemRefType)->getResult();
+    // Record it.
+    fastBufferMap[memref] = fastMemRef;
+  } else {
+    // Reuse the one already created.
+    fastMemRef = fastBufferMap[memref];
+  }
 
   // Create a tag (single element 1-d memref) for the DMA.
-  auto tagMemRefType = b->getMemRefType({1}, b->getIntegerType(32));
-  auto tagMemRef = b->create<AllocOp>(loc, tagMemRefType);
-  auto numElementsSSA = b->create<ConstantIndexOp>(loc, numElements);
+  auto tagMemRefType = top.getMemRefType({1}, top.getIntegerType(32));
+  auto tagMemRef = prologue.create<AllocOp>(loc, tagMemRefType);
+  auto numElementsSSA =
+      top.create<ConstantIndexOp>(loc, numElements.getValue());
 
   // TODO(bondhugula): check for transfer sizes not being a multiple of
   // minDmaTransferSize and handle them appropriately.
   // TODO(bondhugula): Need to use strided DMA for multi-dimensional (>= 2-d)
   // case.
-  b->create<DmaStartOp>(loc, memref, srcIndices, fastMemRef, destIndices,
-                        numElementsSSA, tagMemRef, zeroIndex);
+  if (!region.isWrite()) {
+    b->create<DmaStartOp>(loc, memref, srcIndices, fastMemRef, destIndices,
+                          numElementsSSA, tagMemRef, zeroIndex);
+  } else {
+    // dest and src is switched for the writes (since DMA is from the faster
+    // memory space to the slower one).
+    b->create<DmaStartOp>(loc, fastMemRef, destIndices, memref, srcIndices,
+                          numElementsSSA, tagMemRef, zeroIndex);
+  }
+
+  // Matching DMA wait to block on completion; tag always has a 0 index.
   b->create<DmaWaitOp>(loc, tagMemRef, zeroIndex, numElementsSSA);
 
-  // Replace all uses of the old memref with the promoted one while remapping
+  // Replace all uses of the old memref with the faster one while remapping
   // access indices (subtracting out lower bound offsets for each dimension).
   SmallVector<AffineExpr, 4> remapExprs;
   remapExprs.reserve(rank);
   for (unsigned i = 0; i < rank; i++) {
-    auto d0 = b->getAffineDimExpr(i);
-    remapExprs.push_back(d0 - offsets[i]);
+    auto dim = b->getAffineDimExpr(i);
+    remapExprs.push_back(dim - offsets[i]);
   }
   auto indexRemap = b->getAffineMap(rank, 0, remapExprs, {});
-  replaceAllMemRefUsesWith(memref, cast<MLValue>(fastMemRef), {}, indexRemap);
+  // *Only* those uses within the body of 'forStmt' are replaced.
+  replaceAllMemRefUsesWith(memref, cast<MLValue>(fastMemRef), {}, indexRemap,
+                           &*forStmt->begin());
+  return true;
 }
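The offset construction above can be hard to follow in the abstract. The following self-contained sketch (illustrative only; plain C++ with made-up names, no MLIR types) mirrors the arithmetic: a lower-bound inequality row of the form d - 16*iT >= 0 yields the DMA start offset 16*iT, and the remap d0 -> d0 - offset rebases an original access index into the fast buffer.

#include <cassert>
#include <cstdint>
#include <vector>

// One lower-bound inequality row for a dimension d, in the form
//   1*d + sum_j(outerCoeffs[j] * outerIV[j]) + constant >= 0,
// i.e. d >= -(sum_j outerCoeffs[j] * outerIV[j] + constant).
struct LowerBoundRow {
  std::vector<int64_t> outerCoeffs; // coefficients of outer loop IVs / symbols
  int64_t constant;                 // constant term of the inequality
};

// Evaluates the dimension's lower bound (the DMA start offset) for concrete
// outer IV values, mirroring the AffineExpr built dimension-wise above.
int64_t evalOffset(const LowerBoundRow &lb, const std::vector<int64_t> &ivs) {
  assert(lb.outerCoeffs.size() == ivs.size());
  int64_t offset = 0;
  for (size_t j = 0; j < ivs.size(); ++j)
    offset -= lb.outerCoeffs[j] * ivs[j]; // subtract out each outer IV term
  offset -= lb.constant;                  // and the constant term
  return offset;
}

int main() {
  // Region lower bound d0 >= 16 * iT  <=>  1*d0 + (-16)*iT + 0 >= 0.
  LowerBoundRow lb{{-16}, 0};
  int64_t iT = 3;                        // some outer tile IV value
  int64_t offset = evalOffset(lb, {iT}); // = 48
  assert(offset == 48);

  // Index remap d0 -> d0 - offset: original index 50 lands at buffer index 2.
  int64_t origIndex = 50;
  assert(origIndex - offset == 2);
  return 0;
}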
-bool DmaGeneration::runOnForStmt(ForStmt *forStmt) {
-  walk(forStmt);
-
-  MLFuncBuilder b(forStmt);
-  for (const auto &region : regions) {
-    generateDma(region, forStmt->getLoc(), &b);
-  }
-
-  // This function never leaves the IR in an invalid state.
-  return false;
+/// Returns the nesting depth of this statement, i.e., the number of loops
+/// surrounding this statement.
+// TODO(bondhugula): move this to utilities later.
+static unsigned getNestingDepth(const Statement &stmt) {
+  const Statement *currStmt = &stmt;
+  unsigned depth = 0;
+  while ((currStmt = currStmt->getParentStmt())) {
+    if (isa<ForStmt>(currStmt))
+      depth++;
+  }
+  return depth;
+}
+
+// TODO(bondhugula): make this run on a StmtBlock instead of a 'for' stmt.
+void DmaGeneration::runOnForStmt(ForStmt *forStmt) {
+  // For now (for testing purposes), we'll run this on the outermost among 'for'
+  // stmt's with unit stride, i.e., right at the top of the tile if tiling has
+  // been done. In the future, the DMA generation has to be done at a level
+  // where the generated data fits in a higher level of the memory hierarchy; so
+  // the pass has to be instantiated with additional information that we aren't
+  // provided with at the moment.
+  if (forStmt->getStep() != 1) {
+    if (auto *innerFor = dyn_cast<ForStmt>(&*forStmt->begin())) {
+      runOnForStmt(innerFor);
+    }
+    return;
+  }
+
+  // DMAs will be generated for this depth, i.e., for all data accessed by this
+  // loop.
+  dmaDepth = getNestingDepth(*forStmt);
+
+  regions.clear();
+  fastBufferMap.clear();
+
+  // Walk this 'for' statement to gather all memory regions.
+  walk(forStmt);
+
+  for (const auto &region : regions) {
+    generateDma(*region, forStmt);
+  }
 }
 PassResult DmaGeneration::runOnMLFunction(MLFunction *f) {
-  bool ret = false;
   for (auto &stmt : *f) {
+    // Run on all 'for' statements for now.
     if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) {
-      ret = ret | runOnForStmt(forStmt);
+      runOnForStmt(forStmt);
     }
   }
-  return ret ? failure() : success();
+  // This function never leaves the IR in an invalid state.
+  return success();
 }
static PassRegistration<DmaGeneration> static PassRegistration<DmaGeneration>


@@ -42,7 +42,7 @@ namespace {
 struct LoopTiling : public FunctionPass {
   LoopTiling() : FunctionPass(&LoopTiling::passID) {}
   PassResult runOnMLFunction(MLFunction *f) override;
-  constexpr static unsigned kDefaultTileSize = 32;
+  constexpr static unsigned kDefaultTileSize = 4;
 
   static char passID;
 };
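Note (not part of the diff): with the default tile size lowered to 4, a run that still wants the previous 32x32 tiles can request them explicitly, e.g. "mlir-opt -loop-tile -tile-size=32 input.mlir" (input.mlir is a placeholder); this is the same -tile-size flag the updated loop-tiling RUN line at the end of this change uses.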


@@ -117,7 +117,7 @@ static bool doubleBuffer(const MLValue *oldMemRef, ForStmt *forStmt) {
   return true;
 }
 
-/// Returns false if this succeeds on at least one 'for' stmt.
+/// Returns success if the IR is in a valid state.
 PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) {
   // Do a post order walk so that inner loop DMAs are processed first. This is
   // necessary since 'for' statements nested within would otherwise become
@@ -126,9 +126,9 @@ PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) {
   // epilogue).
   forStmts.clear();
   walkPostOrder(f);
-  bool ret = true;
+  bool ret = false;
   for (auto *forStmt : forStmts) {
-    ret = ret & runOnForStmt(forStmt);
+    ret = ret | runOnForStmt(forStmt);
   }
   return ret ? failure() : success();
 }
@@ -293,9 +293,16 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
   // Get delays stored in map.
   std::vector<uint64_t> delays(forStmt->getStatements().size());
   unsigned s = 0;
-  for (const auto &stmt : *forStmt) {
+  for (auto &stmt : *forStmt) {
     assert(stmtDelayMap.find(&stmt) != stmtDelayMap.end());
     delays[s++] = stmtDelayMap[&stmt];
+    LLVM_DEBUG(
+        // Tagging statements with delays for debugging purposes.
+        if (auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
+          MLFuncBuilder b(opStmt);
+          opStmt->setAttr(b.getIdentifier("delay"),
+                          b.getIntegerAttr(delays[s - 1]));
+        });
   }
 
   if (!isStmtwiseShiftValid(*forStmt, delays)) {


@@ -24,6 +24,7 @@
 #include "mlir/Analysis/AffineAnalysis.h"
 #include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Analysis/Utils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/StmtVisitor.h"
@@ -47,13 +48,15 @@ static bool isMemRefDereferencingOp(const Operation &op) {
 /// old memref's indices to the new memref using the supplied affine map
 /// and adding any additional indices. The new memref could be of a different
 /// shape or rank, but of the same elemental type. Additional indices are added
-/// at the start for now.
+/// at the start. An optional argument 'domOpFilter' restricts the
+/// replacement to only those operations that are dominated by the former.
 // TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily
 // extended to add additional indices at any position.
 bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
                                     MLValue *newMemRef,
                                     ArrayRef<MLValue *> extraIndices,
-                                    AffineMap indexRemap) {
+                                    AffineMap indexRemap,
+                                    const Statement *domStmtFilter) {
   unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
   (void)newMemRefRank; // unused in opt mode
   unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
@@ -82,6 +85,11 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
   for (auto it = oldMemRef->use_begin(); it != oldMemRef->use_end();) {
     StmtOperand &use = *(it++);
     auto *opStmt = cast<OperationStmt>(use.getOwner());
+
+    // Skip this use if it's not dominated by domStmtFilter.
+    if (domStmtFilter && !dominates(*domStmtFilter, *opStmt))
+      continue;
+
     assert(isMemRefDereferencingOp(*opStmt) &&
            "memref deferencing op expected");
@@ -172,7 +180,7 @@ OperationStmt *
 mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
                                   ArrayRef<MLValue *> operands,
                                   ArrayRef<OperationStmt *> affineApplyOps,
-                                  SmallVectorImpl<SSAValue *> &results) {
+                                  SmallVectorImpl<SSAValue *> *results) {
   // Create identity map with same number of dimensions as number of operands.
   auto map = builder->getMultiDimIdentityMap(operands.size());
   // Initialize AffineValueMap with identity map.
@@ -194,9 +202,9 @@ mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
   // Create new AffineApplyOp based on 'valueMap'.
   auto affineApplyOp =
       builder->create<AffineApplyOp>(loc, valueMap.getAffineMap(), outOperands);
-  results.resize(operands.size());
+  results->resize(operands.size());
   for (unsigned i = 0, e = operands.size(); i < e; ++i) {
-    results[i] = affineApplyOp->getResult(i);
+    (*results)[i] = affineApplyOp->getResult(i);
   }
   return cast<OperationStmt>(affineApplyOp->getOperation());
 }
@@ -247,8 +255,8 @@ OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) {
   if (affineApplyOps.empty())
     return nullptr;
 
-  // Check if all uses of the affine apply op's lie in this op stmt
-  // itself, in which case there would be nothing to do.
+  // Check if all uses of the affine apply op's lie only in this op stmt, in
+  // which case there would be nothing to do.
   bool localized = true;
   for (auto *op : affineApplyOps) {
     for (auto *result : op->getResults()) {
@@ -266,7 +274,7 @@ OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) {
   FuncBuilder builder(opStmt);
   SmallVector<SSAValue *, 4> results;
   auto *affineApplyStmt = createComposedAffineApplyOp(
-      &builder, opStmt->getLoc(), subOperands, affineApplyOps, results);
+      &builder, opStmt->getLoc(), subOperands, affineApplyOps, &results);
   assert(results.size() == subOperands.size() &&
          "number of results should be the same as the number of subOperands");

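Not part of the change: a minimal call-site sketch of the new filter argument, assuming a ForStmt *forStmt whose body is being rewritten and in-scope oldMemRef/fastMemRef values (the remap here is just an example). Only uses dominated by the loop body's first statement are rewritten, so a dma_start sitting outside the loop that still reads the original memref is left alone -- this mirrors the call made in DmaGeneration::generateDma above.

// Rebase accesses to 'oldMemRef' onto 'fastMemRef', but only inside the loop.
MLFuncBuilder b(forStmt);
SmallVector<AffineExpr, 4> remapExprs = {b.getAffineDimExpr(0) -
                                         b.getAffineConstantExpr(256)};
auto indexRemap = b.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0,
                                 remapExprs, {});
replaceAllMemRefUsesWith(oldMemRef, fastMemRef, /*extraIndices=*/{}, indexRemap,
                         /*domStmtFilter=*/&*forStmt->begin());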

@@ -1,42 +1,155 @@
// RUN: mlir-opt %s -dma-generate -canonicalize | FileCheck %s

// Index of the buffer for the second DMA is remapped.
// CHECK-DAG: [[MAP:#map[0-9]+]] = (d0) -> (d0 - 256)
// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (d0 * 16 + d1)
// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (d0, d1)

// CHECK-LABEL: mlfunc @loop_nest_1d() {
mlfunc @loop_nest_1d() {
  %A = alloc() : memref<256 x f32>
  %B = alloc() : memref<512 x f32>
  %F = alloc() : memref<256 x f32, 1>
  // First DMA buffer.
  // CHECK: %3 = alloc() : memref<256xf32, 1>
  // Tag for first DMA.
  // CHECK: %4 = alloc() : memref<1xi32>
  // First DMA transfer.
  // CHECK: dma_start %0[%c0], %3[%c0], %c256, %4[%c0] : memref<256xf32>, memref<256xf32, 1>, memref<1xi32>
  // CHECK: dma_wait %4[%c0], %c256 : memref<1xi32>
  // Second DMA buffer.
  // CHECK: %5 = alloc() : memref<256xf32, 1>
  // Tag for second DMA.
  // CHECK: %6 = alloc() : memref<1xi32>
  // Second DMA transfer.
  // CHECK: dma_start %1[%c256], %5[%c0], %c256, %6[%c0] : memref<512xf32>, memref<256xf32, 1>, memref<1xi32>
  // CHECK-NEXT: dma_wait %6[%c0], %c256 : memref<1xi32>
  // CHECK: for %i0 = 0 to 256 {
  // CHECK: %7 = affine_apply #map{{[0-9]+}}(%i0)
  // CHECK-NEXT: %8 = load %3[%7] : memref<256xf32, 1>
  // CHECK: %9 = affine_apply #map{{[0-9]+}}(%i0)
  // CHECK: %10 = affine_apply [[MAP]](%9)
  // CHECK-NEXT: %11 = load %5[%10] : memref<256xf32, 1>
  // Already in faster memory space.
  // CHECK: %12 = load %2[%i0] : memref<256xf32, 1>
  // CHECK-NEXT: }
  // CHECK-NEXT: return
  for %i = 0 to 256 {
    load %A[%i] : memref<256 x f32>
    %idx = affine_apply (d0) -> (d0 + 256)(%i)
    load %B[%idx] : memref<512 x f32>
    load %F[%i] : memref<256 x f32, 1>
  }
  return
}

// CHECK-LABEL: mlfunc @loop_nest_high_d
// CHECK: %c16384 = constant 16384 : index
// CHECK-NEXT: %0 = alloc() : memref<512x32xf32, 1>
// CHECK-NEXT: %1 = alloc() : memref<1xi32>
// INCOMING DMA for B
// CHECK-NEXT: dma_start %arg1[%c0, %c0], %0[%c0, %c0], %c16384, %1[%c0] : memref<512x32xf32>, memref<512x32xf32, 1>, memref<1xi32>
// CHECK-NEXT: dma_wait %1[%c0], %c16384 : memref<1xi32>
// CHECK-NEXT: %2 = alloc() : memref<512x32xf32, 1>
// CHECK-NEXT: %3 = alloc() : memref<1xi32>
// INCOMING DMA for A.
// CHECK-NEXT: dma_start %arg0[%c0, %c0], %2[%c0, %c0], %c16384, %3[%c0] : memref<512x32xf32>, memref<512x32xf32, 1>, memref<1xi32>
// CHECK-NEXT: dma_wait %3[%c0], %c16384 : memref<1xi32>
// CHECK-NEXT: %4 = alloc() : memref<512x32xf32, 1>
// CHECK-NEXT: %5 = alloc() : memref<1xi32>
// INCOMING DMA for C.
// CHECK-NEXT: dma_start %arg2[%c0, %c0], %4[%c0, %c0], %c16384, %5[%c0] : memref<512x32xf32>, memref<512x32xf32, 1>, memref<1xi32>
// CHECK-NEXT: dma_wait %5[%c0], %c16384 : memref<1xi32>
// CHECK-NEXT: %6 = alloc() : memref<1xi32>
// CHECK-NEXT: for %i0 = 0 to 32 {
// CHECK-NEXT: for %i1 = 0 to 32 {
// CHECK-NEXT: for %i2 = 0 to 32 {
// CHECK-NEXT: for %i3 = 0 to 16 {
// CHECK-NEXT: %7 = affine_apply #map{{[0-9]+}}(%i1, %i3)
// CHECK-NEXT: %8 = affine_apply #map{{[0-9]+}}(%7, %i0)
// CHECK-NEXT: %9 = load %0[%8#0, %8#1] : memref<512x32xf32, 1>
// CHECK-NEXT: "foo"(%9) : (f32) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: for %i4 = 0 to 16 {
// CHECK-NEXT: %10 = affine_apply #map{{[0-9]+}}(%i2, %i4)
// CHECK-NEXT: %11 = affine_apply #map{{[0-9]+}}(%10, %i1)
// CHECK-NEXT: %12 = load %2[%11#0, %11#1] : memref<512x32xf32, 1>
// CHECK-NEXT: "bar"(%12) {mxu_id: 0} : (f32) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: for %i5 = 0 to 16 {
// CHECK-NEXT: %13 = "abc_compute"() : () -> f32
// CHECK-NEXT: %14 = affine_apply #map{{[0-9]+}}(%i2, %i5)
// CHECK-NEXT: %15 = affine_apply #map{{[0-9]+}}(%14, %i0)
// CHECK-NEXT: %16 = load %4[%15#0, %15#1] : memref<512x32xf32, 1>
// CHECK-NEXT: %17 = "addf32"(%13, %16) : (f32, f32) -> f32
// CHECK-NEXT: %18 = affine_apply #map{{[0-9]+}}(%14, %i0)
// CHECK-NEXT: store %17, %4[%18#0, %18#1] : memref<512x32xf32, 1>
// CHECK-NEXT: }
// CHECK-NEXT: "foobar"() : () -> ()
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// OUTGOING DMA for C.
// CHECK-NEXT: dma_start %4[%c0, %c0], %arg2[%c0, %c0], %c16384, %6[%c0] : memref<512x32xf32, 1>, memref<512x32xf32>, memref<1xi32>
// CHECK-NEXT: dma_wait %6[%c0], %c16384 : memref<1xi32>
// CHECK-NEXT: return
// CHECK-NEXT:}
mlfunc @loop_nest_high_d(%A: memref<512 x 32 x f32>,
%B: memref<512 x 32 x f32>, %C: memref<512 x 32 x f32>) {
// DMAs will be performed at this level (jT is the first loop without a stride).
// A and B are read, while C is both read and written. A total of three new buffers
// are allocated and existing load's/store's are replaced by accesses to those buffers.
for %jT = 0 to 32 {
for %kT = 0 to 32 {
for %iT = 0 to 32 {
for %kk = 0 to 16 { // k intratile
%k = affine_apply (d0, d1) -> (16*d0 + d1) (%kT, %kk)
%v0 = load %B[%k, %jT] : memref<512 x 32 x f32>
"foo"(%v0) : (f32) -> ()
}
for %ii = 0 to 16 { // i intratile.
%i = affine_apply (d0, d1) -> (16*d0 + d1)(%iT, %ii)
%v1 = load %A[%i, %kT] : memref<512 x 32 x f32>
"bar"(%v1) {mxu_id: 0} : (f32) -> ()
}
for %ii_ = 0 to 16 { // i intratile.
%v2 = "abc_compute"() : () -> f32
%i_ = affine_apply (d0, d1) -> (16*d0 + d1)(%iT, %ii_)
%v3 = load %C[%i_, %jT] : memref<512 x 32 x f32>
%v4 = "addf32"(%v2, %v3) : (f32, f32) -> (f32)
store %v4, %C[%i_, %jT] : memref<512 x 32 x f32>
}
"foobar"() : () -> ()
}
}
}
return
}
// A loop nest with a modulo 2 access.
//
// CHECK-LABEL: mlfunc @loop_nest_modulo() {
// CHECK: %0 = alloc() : memref<256x8xf32>
// CHECK-NEXT: for %i0 = 0 to 32 step 4 {
// CHECK-NEXT: %1 = alloc() : memref<32x2xf32, 1>
// CHECK-NEXT: %2 = alloc() : memref<1xi32>
// CHECK-NEXT: dma_start %0[%c0, %c0], %1[%c0, %c0], %c64, %2[%c0] : memref<256x8xf32>, memref<32x2xf32, 1>, memref<1xi32>
// CHECK-NEXT: dma_wait %2[%c0], %c64 : memref<1xi32>
// CHECK-NEXT: for %i1 = 0 to 8 {
// ...
// ...
// CHECK: }
// CHECK-NEXT: }
// CHECK-NEXT: return
mlfunc @loop_nest_modulo() {
%A = alloc() : memref<256 x 8 x f32>
for %i = 0 to 32 step 4 {
// DMAs will be performed at this level (%j is the first unit stride loop)
for %j = 0 to 8 {
%idx = affine_apply (d0) -> (d0 mod 2) (%j)
// A buffer of size 32 x 2 will be allocated (original buffer was 256 x 8).
%v = load %A[%i, %idx] : memref<256 x 8 x f32>
}
  }
  return
}


@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -loop-tile | FileCheck %s
+// RUN: mlir-opt %s -loop-tile -tile-size=32 | FileCheck %s
 
 // CHECK: #map0 = (d0) -> (d0 + 32)
 // CHECK: #map1 = (d0) -> (d0 + 32, 50)