forked from OSchip/llvm-project
Adds a dependence check to test whether two accesses to the same memref access the same element.
- Builds access functions and iterations domains for each access. - Builds dependence polyhedron constraint system which has equality constraints for equated access functions and inequality constraints for iteration domain loop bounds. - Runs elimination on the dependence polyhedron to test if no dependence exists between the accesses. - Adds a trivial LoopFusion transformation pass with a simple test policy to test dependence between accesses to the same memref in adjacent loops. - The LoopFusion pass will be extended in subsequent CLs. PiperOrigin-RevId: 219630898
This commit is contained in:
parent
21638dcda9
commit
f28e4df666
|
@ -31,6 +31,9 @@
|
|||
namespace mlir {
|
||||
|
||||
class AffineExpr;
|
||||
class AffineValueMap;
|
||||
class ForStmt;
|
||||
class MLIRContext;
|
||||
class MLValue;
|
||||
class OperationStmt;
|
||||
|
||||
|
@ -48,6 +51,10 @@ void getReachableAffineApplyOps(
|
|||
llvm::ArrayRef<MLValue *> operands,
|
||||
llvm::SmallVectorImpl<OperationStmt *> &affineApplyOps);
|
||||
|
||||
/// Forward substitutes into 'valueMap' all AffineApplyOps reachable from the
|
||||
/// operands of 'valueMap'.
|
||||
void forwardSubstituteReachableOps(AffineValueMap *valueMap);
|
||||
|
||||
/// Flattens 'expr' into 'flattenedExpr'. Returns true on success or false
|
||||
/// if 'expr' was unable to be flattened (i.e. because it was not pure affine,
|
||||
/// or because it contained mod's and div's that could not be eliminated
|
||||
|
@ -56,6 +63,22 @@ bool getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
|
|||
unsigned numSymbols,
|
||||
llvm::SmallVectorImpl<int64_t> *flattenedExpr);
|
||||
|
||||
/// Checks whether two accesses to the same memref access the same element.
|
||||
/// Each access is specified using the MemRefAccess structure, which contains
|
||||
/// the operation statement, indices and memref associated with the access.
|
||||
/// Returns 'false' if it can be determined conclusively that the accesses do
|
||||
/// not access the same memref element. Returns 'true' otherwise.
|
||||
struct MemRefAccess {
|
||||
const MLValue *memref;
|
||||
const OperationStmt *opStmt;
|
||||
llvm::SmallVector<MLValue *, 4> indices;
|
||||
// Populates 'accessMap' with composition of AffineApplyOps reachable from
|
||||
// 'indices'.
|
||||
void getAccessMap(AffineValueMap *accessMap) const;
|
||||
};
|
||||
bool checkMemrefAccessDependence(const MemRefAccess &srcAccess,
|
||||
const MemRefAccess &dstAccess);
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_ANALYSIS_AFFINE_ANALYSIS_H
|
||||
|
|
|
@ -42,6 +42,7 @@ class HyperRectangularSet;
|
|||
/// A mutable affine map. Its affine expressions are however unique.
|
||||
struct MutableAffineMap {
|
||||
public:
|
||||
MutableAffineMap() {}
|
||||
MutableAffineMap(AffineMap map);
|
||||
|
||||
AffineExpr getResult(unsigned idx) const { return results[idx]; }
|
||||
|
@ -56,6 +57,9 @@ public:
|
|||
/// Returns true if the idx'th result expression is a multiple of factor.
|
||||
bool isMultipleOf(unsigned idx, int64_t factor) const;
|
||||
|
||||
/// Resets this MutableAffineMap with 'map'.
|
||||
void reset(AffineMap map);
|
||||
|
||||
/// Simplify the (result) expressions in this map using analysis (used by
|
||||
//-simplify-affine-expr pass).
|
||||
void simplify();
|
||||
|
@ -120,6 +124,9 @@ private:
|
|||
// TODO(bondhugula): Some of these classes could go into separate files.
|
||||
class AffineValueMap {
|
||||
public:
|
||||
// Creates an empty AffineValueMap (users should call 'reset' to reset map
|
||||
// and operands).
|
||||
AffineValueMap() {}
|
||||
AffineValueMap(const AffineApplyOp &op);
|
||||
AffineValueMap(const AffineBound &bound);
|
||||
AffineValueMap(AffineMap map);
|
||||
|
@ -145,6 +152,8 @@ public:
|
|||
// can be used to amortize the cost of simplification over multiple fwd
|
||||
// substitutions).
|
||||
|
||||
// Resets this AffineValueMap with 'map' and 'operands'.
|
||||
void reset(AffineMap map, ArrayRef<MLValue *> operands);
|
||||
/// Return true if the idx^th result can be proved to be a multiple of
|
||||
/// 'factor', false otherwise.
|
||||
inline bool isMultipleOf(unsigned idx, int64_t factor) const;
|
||||
|
|
|
@ -56,6 +56,10 @@ public:
|
|||
/// Returns a single constant result affine map.
|
||||
static AffineMap getConstantMap(int64_t val, MLIRContext *context);
|
||||
|
||||
/// Returns an AffineMap with 'numDims' identity result dim exprs.
|
||||
static AffineMap getMultiDimIdentityMap(unsigned numDims,
|
||||
MLIRContext *context);
|
||||
|
||||
MLIRContext *getContext() const;
|
||||
|
||||
explicit operator bool() { return map; }
|
||||
|
|
|
@ -52,6 +52,9 @@ FunctionPass *createLoopUnrollAndJamPass(int unrollJamFactor = -1);
|
|||
/// Creates an simplification pass for affine structures.
|
||||
FunctionPass *createSimplifyAffineStructuresPass();
|
||||
|
||||
/// Creates a loop fusion pass which fuses loops in MLFunctions.
|
||||
FunctionPass *createLoopFusionPass();
|
||||
|
||||
/// Creates a pass to pipeline explicit movement of data across levels of the
|
||||
/// memory hierarchy.
|
||||
FunctionPass *createPipelineDataTransferPass();
|
||||
|
|
|
@ -25,6 +25,8 @@
|
|||
#include "mlir/IR/AffineExprVisitor.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Statements.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
|
@ -370,3 +372,475 @@ void mlir::getReachableAffineApplyOps(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Forward substitutes into 'valueMap' all AffineApplyOps reachable from the
|
||||
// operands of 'valueMap'.
|
||||
void mlir::forwardSubstituteReachableOps(AffineValueMap *valueMap) {
|
||||
// Gather AffineApplyOps reachable from 'indices'.
|
||||
SmallVector<OperationStmt *, 4> affineApplyOps;
|
||||
getReachableAffineApplyOps(valueMap->getOperands(), affineApplyOps);
|
||||
// Compose AffineApplyOps in 'affineApplyOps'.
|
||||
for (auto *opStmt : affineApplyOps) {
|
||||
assert(opStmt->isa<AffineApplyOp>());
|
||||
auto affineApplyOp = opStmt->dyn_cast<AffineApplyOp>();
|
||||
// Forward substitute 'affineApplyOp' into 'valueMap'.
|
||||
valueMap->forwardSubstitute(*affineApplyOp);
|
||||
}
|
||||
}
|
||||
|
||||
// Adds loop upper and lower bound inequalities to 'domain' for each ForStmt
|
||||
// value in 'values'. Requires that the first 'numDims' MLValues in 'values'
|
||||
// are ForStmts. Returns true if lower/upper bound inequalities were
|
||||
// successfully added, returns false otherwise.
|
||||
// TODO(andydavis) Get operands for loop bounds so we can add domain
|
||||
// constraints for non-constant loop bounds.
|
||||
static bool addLoopBoundInequalities(unsigned numDims,
|
||||
ArrayRef<const MLValue *> values,
|
||||
FlatAffineConstraints *domain) {
|
||||
assert(values.size() >= numDims);
|
||||
unsigned numIds = values.size();
|
||||
// Add InEqualties for loop bounds.
|
||||
SmallVector<int64_t, 4> ineq;
|
||||
ineq.resize(numIds + 1);
|
||||
for (unsigned i = 0; i < numDims; ++i) {
|
||||
const ForStmt *forStmt = dyn_cast<ForStmt>(values[i]);
|
||||
if (!forStmt || !forStmt->hasConstantBounds())
|
||||
return false;
|
||||
// Zero fill
|
||||
std::fill(ineq.begin(), ineq.end(), 0);
|
||||
// TODO(andydavis, bondhugula) Add methods for addUpper/LowerBound.
|
||||
// Add inequality for lower bound.
|
||||
ineq[i] = 1;
|
||||
ineq[numIds] = -forStmt->getConstantLowerBound();
|
||||
domain->addInequality(ineq);
|
||||
// Add inequality for upper bound.
|
||||
ineq[i] = -1;
|
||||
ineq[numIds] = forStmt->getConstantUpperBound();
|
||||
domain->addInequality(ineq);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// IterationDomainContext encapsulates the state required to represent
|
||||
// the iteration domain of an OperationStmt.
|
||||
struct IterationDomainContext {
|
||||
// Set of inequality constraint pairs, where each pair represents the
|
||||
// upper/lower bounds of a ForStmt in the iteration domain.
|
||||
FlatAffineConstraints domain;
|
||||
// The number of dimension identifiers in 'values'.
|
||||
unsigned numDims;
|
||||
// The list of MLValues in this iteration domain, with MLValues in
|
||||
// [0, numDims) representing dimension identifiers, and MLValues in
|
||||
// [numDims, values.size()) representing symbol identifiers.
|
||||
SmallVector<const MLValue *, 4> values;
|
||||
IterationDomainContext() : numDims(0) {}
|
||||
unsigned getNumDims() { return numDims; }
|
||||
unsigned getNumSymbols() { return values.size() - numDims; }
|
||||
};
|
||||
|
||||
// Computes the iteration domain for 'opStmt' and populates 'ctx', which
|
||||
// encapsulates the following state for each ForStmt in 'opStmt's iteration
|
||||
// domain:
|
||||
// *) adds inequality constraints representing the ForStmt upper/lower bounds.
|
||||
// *) adds MLValues and symbols for the ForStmt and its operands to a list.
|
||||
// TODO(andydavis) Add support for IfStmts in iteration domain.
|
||||
// TODO(andydavis) Handle non-constant loop bounds by composing affine maps
|
||||
// for each ForStmt loop bound and adding de-duped ids/symbols to iteration
|
||||
// domain context.
|
||||
// TODO(andydavis) Capture the context of the symbols. For example, check
|
||||
// if a symbol is the result of a constant operation, and set the symbol to
|
||||
// that value in FlatAffineConstraints (using setIdToConstant).
|
||||
static bool getIterationDomainContext(const OperationStmt *opStmt,
|
||||
IterationDomainContext *ctx) {
|
||||
// Walk up tree storing parent statements in 'loops'.
|
||||
// TODO(andydavis) Extend this to gather enclosing IfStmts and consider
|
||||
// factoring it out into a utility function.
|
||||
SmallVector<const ForStmt *, 4> loops;
|
||||
const auto *currStmt = opStmt->getParentStmt();
|
||||
while (currStmt != nullptr) {
|
||||
if (isa<IfStmt>(currStmt))
|
||||
return false;
|
||||
assert(isa<ForStmt>(currStmt));
|
||||
auto *forStmt = dyn_cast<ForStmt>(currStmt);
|
||||
loops.push_back(forStmt);
|
||||
currStmt = currStmt->getParentStmt();
|
||||
}
|
||||
// Iterate through 'loops' from outer-most loop to inner-most loop.
|
||||
// Populate 'values'.
|
||||
ctx->values.reserve(loops.size());
|
||||
for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) {
|
||||
auto *forStmt = loops[i];
|
||||
// TODO(andydavis) Compose affine maps into lower/upper bounds of 'forStmt'
|
||||
// and add de-duped symbols to ctx.symbols.
|
||||
if (!forStmt->hasConstantBounds())
|
||||
return false;
|
||||
ctx->values.push_back(forStmt);
|
||||
ctx->numDims++;
|
||||
}
|
||||
// Resize flat affine constraint system based on num dims symbols found.
|
||||
unsigned numDims = ctx->getNumDims();
|
||||
unsigned numSymbols = ctx->getNumSymbols();
|
||||
ctx->domain.reset(/*newNumReservedInequalities=*/2 * numDims,
|
||||
/*newNumReservedEqualities=*/0,
|
||||
/*newNumReservedCols=*/numDims + numSymbols + 1, numDims,
|
||||
numSymbols);
|
||||
return addLoopBoundInequalities(numDims, ctx->values, &ctx->domain);
|
||||
}
|
||||
|
||||
// Builds a map from MLValue to identifier position in a new merged identifier
|
||||
// list, which is the result of merging dim/symbol lists from src/dst
|
||||
// iteration domains. The format of the new merged list is as follows:
|
||||
//
|
||||
// [src-dim-identifiers, dst-dim-identifiers, symbol-identifiers]
|
||||
//
|
||||
// This method populates 'srcDimPosMap' and 'dstDimPosMap' with mappings from
|
||||
// operand MLValues in 'srcAccessMap'/'dstAccessMap' to the position of these
|
||||
// values in the merged list.
|
||||
// In addition, this method populates 'symbolPosMap' with mappings from
|
||||
// operand MLValues in both 'srcIterationDomainContext' and
|
||||
// 'dstIterationDomainContext' to position of these values in the merged list.
|
||||
static void buildDimAndSymbolPositionMaps(
|
||||
const IterationDomainContext &srcIterationDomainContext,
|
||||
const IterationDomainContext &dstIterationDomainContext,
|
||||
const AffineValueMap &srcAccessMap, const AffineValueMap &dstAccessMap,
|
||||
DenseMap<const MLValue *, unsigned> *srcDimPosMap,
|
||||
DenseMap<const MLValue *, unsigned> *dstDimPosMap,
|
||||
DenseMap<const MLValue *, unsigned> *symbolPosMap) {
|
||||
unsigned pos = 0;
|
||||
|
||||
auto updatePosMap = [&](DenseMap<const MLValue *, unsigned> *posMap,
|
||||
ArrayRef<const MLValue *> values, unsigned start,
|
||||
unsigned limit) {
|
||||
for (unsigned i = start; i < limit; ++i) {
|
||||
auto *value = values[i];
|
||||
auto it = posMap->find(value);
|
||||
if (it == posMap->end()) {
|
||||
(*posMap)[value] = pos++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
AffineMap srcMap = srcAccessMap.getAffineMap();
|
||||
AffineMap dstMap = dstAccessMap.getAffineMap();
|
||||
|
||||
// Update position map with src dimension identifiers from iteration domain
|
||||
// and access function.
|
||||
updatePosMap(srcDimPosMap, srcIterationDomainContext.values, 0,
|
||||
srcIterationDomainContext.numDims);
|
||||
|
||||
// Update position map with dst dimension identifiers from iteration domain
|
||||
// and access function.
|
||||
updatePosMap(dstDimPosMap, dstIterationDomainContext.values, 0,
|
||||
dstIterationDomainContext.numDims);
|
||||
|
||||
// Update position map with src symbol identifiers from iteration domain
|
||||
// and access function.
|
||||
updatePosMap(symbolPosMap, srcIterationDomainContext.values,
|
||||
dstIterationDomainContext.numDims,
|
||||
srcIterationDomainContext.values.size());
|
||||
updatePosMap(symbolPosMap, srcAccessMap.getOperands(), srcMap.getNumDims(),
|
||||
srcMap.getNumDims() + srcMap.getNumSymbols());
|
||||
|
||||
// Update position map with dst symbol identifiers from iteration domain
|
||||
// and access function.
|
||||
updatePosMap(symbolPosMap, dstIterationDomainContext.values,
|
||||
dstIterationDomainContext.numDims,
|
||||
dstIterationDomainContext.values.size());
|
||||
updatePosMap(symbolPosMap, dstAccessMap.getOperands(), dstMap.getNumDims(),
|
||||
dstMap.getNumDims() + dstMap.getNumSymbols());
|
||||
}
|
||||
|
||||
static unsigned getPos(const DenseMap<const MLValue *, unsigned> &posMap,
|
||||
const MLValue *value) {
|
||||
auto it = posMap.find(value);
|
||||
assert(it != posMap.end());
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// Adds iteration domain constraints from 'ctx.domain' into 'outputFac'.
|
||||
// Uses 'dimPosMap' to map from dim operand value in 'ctx.values', to dim
|
||||
// position in 'outputFac'.
|
||||
// Uses 'symbolPosMap' to map from symbol operand value in 'ctx.values', to
|
||||
// symbol position in 'outputFac'.
|
||||
static void
|
||||
addDomainConstraints(const IterationDomainContext &ctx,
|
||||
const DenseMap<const MLValue *, unsigned> &dimPosMap,
|
||||
const DenseMap<const MLValue *, unsigned> &symbolPosMap,
|
||||
FlatAffineConstraints *outputFac) {
|
||||
unsigned inputNumIneq = ctx.domain.getNumInequalities();
|
||||
unsigned inputNumDims = ctx.domain.getNumDimIds();
|
||||
unsigned inputNumSymbols = ctx.domain.getNumSymbolIds();
|
||||
unsigned inputNumIds = inputNumDims + inputNumSymbols;
|
||||
|
||||
unsigned outputNumDims = outputFac->getNumDimIds();
|
||||
unsigned outputNumSymbols = outputFac->getNumSymbolIds();
|
||||
unsigned outputNumIds = outputNumDims + outputNumSymbols;
|
||||
|
||||
SmallVector<int64_t, 4> eq;
|
||||
eq.resize(outputNumIds + 1);
|
||||
for (unsigned i = 0; i < inputNumIneq; ++i) {
|
||||
// Zero fill.
|
||||
std::fill(eq.begin(), eq.end(), 0);
|
||||
// Add dim identifiers.
|
||||
for (unsigned j = 0; j < inputNumDims; ++j)
|
||||
eq[getPos(dimPosMap, ctx.values[j])] = ctx.domain.atIneq(i, j);
|
||||
// Add symbol identifiers.
|
||||
for (unsigned j = inputNumDims; j < inputNumIds; ++j) {
|
||||
eq[getPos(symbolPosMap, ctx.values[j])] = ctx.domain.atIneq(i, j);
|
||||
}
|
||||
// Add constant term.
|
||||
eq[outputNumIds] = ctx.domain.atIneq(i, inputNumIds);
|
||||
// Add inequality constraint.
|
||||
outputFac->addInequality(eq);
|
||||
}
|
||||
}
|
||||
|
||||
// Adds equality constraints that equate src and dst access functions
|
||||
// represented by 'srcAccessMap' and 'dstAccessMap' for each result.
|
||||
// Requires that 'srcAccessMap' and 'dstAccessMap' have the same results count.
|
||||
// For example, given the following two accesses functions to a 2D memref:
|
||||
//
|
||||
// Source access function:
|
||||
// (a0 * d0 + a1 * s0 + a2, b0 * d0 + b1 * s0 + b2)
|
||||
//
|
||||
// Destination acceses function:
|
||||
// (c0 * d0 + c1 * s0 + c2, f0 * d0 + f1 * s0 + f2)
|
||||
//
|
||||
// This method constructs the following equality constraints in 'outputFac',
|
||||
// by equating the access functions for each result (i.e. each memref dim).
|
||||
// (notice that 'd0' for the destination access function is mapped into 'd0'
|
||||
// in the equality constraint):
|
||||
//
|
||||
// d0 d1 s0 c
|
||||
// -- -- -- --
|
||||
// a0 -c0 (a1 - c1) (a1 - c2) = 0
|
||||
// b0 -f0 (b1 - f1) (b1 - f2) = 0
|
||||
//
|
||||
bool addMemRefAccessConstraints(
|
||||
const AffineValueMap &srcAccessMap, const AffineValueMap &dstAccessMap,
|
||||
const DenseMap<const MLValue *, unsigned> &srcDimPosMap,
|
||||
const DenseMap<const MLValue *, unsigned> &dstDimPosMap,
|
||||
const DenseMap<const MLValue *, unsigned> &symbolPosMap,
|
||||
FlatAffineConstraints *outputFac) {
|
||||
AffineMap srcMap = srcAccessMap.getAffineMap();
|
||||
AffineMap dstMap = dstAccessMap.getAffineMap();
|
||||
assert(srcMap.getNumResults() == dstMap.getNumResults());
|
||||
unsigned numResults = srcMap.getNumResults();
|
||||
|
||||
unsigned srcNumDims = srcMap.getNumDims();
|
||||
unsigned srcNumSymbols = srcMap.getNumSymbols();
|
||||
unsigned srcNumIds = srcNumDims + srcNumSymbols;
|
||||
ArrayRef<MLValue *> srcOperands = srcAccessMap.getOperands();
|
||||
|
||||
unsigned dstNumDims = dstMap.getNumDims();
|
||||
unsigned dstNumSymbols = dstMap.getNumSymbols();
|
||||
unsigned dstNumIds = dstNumDims + dstNumSymbols;
|
||||
ArrayRef<MLValue *> dstOperands = dstAccessMap.getOperands();
|
||||
|
||||
unsigned outputNumDims = outputFac->getNumDimIds();
|
||||
unsigned outputNumSymbols = outputFac->getNumSymbolIds();
|
||||
unsigned outputNumIds = outputNumDims + outputNumSymbols;
|
||||
|
||||
SmallVector<int64_t, 4> eq;
|
||||
eq.resize(outputNumIds + 1);
|
||||
SmallVector<int64_t, 4> flattenedExpr;
|
||||
for (unsigned i = 0; i < numResults; ++i) {
|
||||
// Zero fill.
|
||||
std::fill(eq.begin(), eq.end(), 0);
|
||||
// Get flattened AffineExpr for result 'i' from src access function.
|
||||
auto srcExpr = srcMap.getResult(i);
|
||||
flattenedExpr.clear();
|
||||
if (!getFlattenedAffineExpr(srcExpr, srcNumDims, srcNumSymbols,
|
||||
&flattenedExpr))
|
||||
return false;
|
||||
// Add dim identifier coefficients from src access function.
|
||||
for (unsigned j = 0, e = srcNumDims; j < e; ++j)
|
||||
eq[getPos(srcDimPosMap, srcOperands[j])] = flattenedExpr[j];
|
||||
// Add symbol identifiers from src access function.
|
||||
for (unsigned j = srcNumDims; j < srcNumIds; ++j)
|
||||
eq[getPos(symbolPosMap, srcOperands[j])] = flattenedExpr[j];
|
||||
// Add constant term.
|
||||
eq[outputNumIds] = flattenedExpr[srcNumIds];
|
||||
|
||||
// Get flattened AffineExpr for result 'i' from dst access function.
|
||||
auto dstExpr = dstMap.getResult(i);
|
||||
flattenedExpr.clear();
|
||||
if (!getFlattenedAffineExpr(dstExpr, dstNumDims, dstNumSymbols,
|
||||
&flattenedExpr))
|
||||
return false;
|
||||
// Add dim identifier coefficients from dst access function.
|
||||
for (unsigned j = 0, e = dstNumDims; j < e; ++j)
|
||||
eq[getPos(dstDimPosMap, dstOperands[j])] = -flattenedExpr[j];
|
||||
// Add symbol identifiers from dst access function.
|
||||
for (unsigned j = dstNumDims; j < dstNumIds; ++j)
|
||||
eq[getPos(symbolPosMap, dstOperands[j])] -= flattenedExpr[j];
|
||||
// Add constant term.
|
||||
eq[outputNumIds] -= flattenedExpr[dstNumIds];
|
||||
// Add equality constraint.
|
||||
outputFac->addEquality(eq);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Populates 'accessMap' with composition of AffineApplyOps reachable from
|
||||
// indices of MemRefAccess.
|
||||
void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const {
|
||||
auto memrefType = memref->getType().cast<MemRefType>();
|
||||
// Create identity map with same number of dimensions as 'memrefType' rank.
|
||||
auto map = AffineMap::getMultiDimIdentityMap(memrefType.getRank(),
|
||||
memref->getType().getContext());
|
||||
// Reset 'accessMap' and 'map' and access 'indices'.
|
||||
accessMap->reset(map, indices);
|
||||
// Compose 'accessMap' with reachable AffineApplyOps.
|
||||
forwardSubstituteReachableOps(accessMap);
|
||||
}
|
||||
|
||||
// Builds a flat affine constraint system to check if there exists a dependence
|
||||
// between memref accesses 'srcAccess' and 'dstAccess'.
|
||||
// Returns 'false' if the accesses can be definitively shown not to access the
|
||||
// same element. Returns 'true' otherwise.
|
||||
//
|
||||
// The memref access dependence check is comprised of the following steps:
|
||||
// *) Compute access functions for each access. Access functions are computed
|
||||
// using AffineValueMaps initialized with the indices from an access, then
|
||||
// composed with AffineApplyOps reachable from operands of that access,
|
||||
// until operands of the AffineValueMap are loop IVs or symbols.
|
||||
// *) Build iteration domain constraints for each access. Iteration domain
|
||||
// constraints are pairs of inequality contraints representing the
|
||||
// upper/lower loop bounds for each ForStmt in the loop nest associated
|
||||
// with each access.
|
||||
// *) Build dimension and symbol position maps for each access, which map
|
||||
// MLValues from access functions and iteration domains to their position
|
||||
// in the merged constraint system build by this method.
|
||||
//
|
||||
// This method builds a constraint system with the following column format:
|
||||
//
|
||||
// [src-dim-identifiers, dst-dim-identifiers, symbols, constant]
|
||||
//
|
||||
// For example, given the following MLIR code with with "source" and
|
||||
// "destination" accesses to the same memref labled, and symbols %M, %N, %K:
|
||||
//
|
||||
// for %i0 = 0 to 100 {
|
||||
// for %i1 = 0 to 50 {
|
||||
// %a0 = affine_apply
|
||||
// (d0, d1) -> (d0 * 2 - d1 * 4 + s1, d1 * 3 - s0) (%i0, %i1)[%M, %N]
|
||||
// // Source memref access.
|
||||
// store %v0, %m[%a0#0, %a0#1] : memref<4x4xf32>
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// for %i2 = 0 to 100 {
|
||||
// for %i3 = 0 to 50 {
|
||||
// %a1 = affine_apply
|
||||
// (d0, d1) -> (d0 * 7 + d1 * 9 - s1, d1 * 11 + s0) (%i2, %i3)[%K, %M]
|
||||
// // Destination memref access.
|
||||
// %v1 = load %m[%a1#0, %a1#1] : memref<4x4xf32>
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// The access functions would be the following:
|
||||
//
|
||||
// src: (%i0 * 2 - %i1 * 4 + %N, %i1 * 3 - %M)
|
||||
// src: (%i2 * 7 + %i3 * 9 - %M, %i3 * 11 - %K)
|
||||
//
|
||||
// The iteration domains for the src/dst accesses would be the following:
|
||||
//
|
||||
// src: 0 <= %i0 <= 100, 0 <= %i1 <= 50
|
||||
// dst: 0 <= %i2 <= 100, 0 <= %i3 <= 50
|
||||
//
|
||||
// The symbols by both accesses would be assigned to a canonical position order
|
||||
// which will be used in the dependence constraint system:
|
||||
//
|
||||
// symbol name: %M %N %K
|
||||
// symbol pos: 0 1 2
|
||||
//
|
||||
// Equality constraints are built by equating each result of src/destination
|
||||
// access functions. For this example, the folloing two equality constraints
|
||||
// will be added to the dependence constraint system:
|
||||
//
|
||||
// [src_dim0, src_dim1, dst_dim0, dst_dim1, sym0, sym1, sym2, const]
|
||||
// 2 -4 -7 -9 1 1 0 0 = 0
|
||||
// 0 3 0 -11 -1 0 1 0 = 0
|
||||
//
|
||||
// Inequality constraints from the iteration domain will be meged into
|
||||
// the dependence constraint system
|
||||
//
|
||||
// [src_dim0, src_dim1, dst_dim0, dst_dim1, sym0, sym1, sym2, const]
|
||||
// 1 0 0 0 0 0 0 0 >= 0
|
||||
// -1 0 0 0 0 0 0 100 >= 0
|
||||
// 0 1 0 0 0 0 0 0 >= 0
|
||||
// 0 -1 0 0 0 0 0 50 >= 0
|
||||
// 0 0 1 0 0 0 0 0 >= 0
|
||||
// 0 0 -1 0 0 0 0 100 >= 0
|
||||
// 0 0 0 1 0 0 0 0 >= 0
|
||||
// 0 0 0 -1 0 0 0 50 >= 0
|
||||
//
|
||||
//
|
||||
// TODO(andydavis) Support AffineExprs mod/floordiv/ceildiv.
|
||||
// TODO(andydavis) Add precedence order constraints for accesses that
|
||||
// share a common loop.
|
||||
bool mlir::checkMemrefAccessDependence(const MemRefAccess &srcAccess,
|
||||
const MemRefAccess &dstAccess) {
|
||||
// Return 'false' if these accesses do not acces the same memref.
|
||||
if (srcAccess.memref != dstAccess.memref)
|
||||
return false;
|
||||
|
||||
// Get composed access function for 'srcAccess'.
|
||||
AffineValueMap srcAccessMap;
|
||||
srcAccess.getAccessMap(&srcAccessMap);
|
||||
|
||||
// Get composed access function for 'dstAccess'.
|
||||
AffineValueMap dstAccessMap;
|
||||
dstAccess.getAccessMap(&dstAccessMap);
|
||||
|
||||
// Get iteration domain context for 'srcAccess'.
|
||||
IterationDomainContext srcIterationDomainContext;
|
||||
if (!getIterationDomainContext(srcAccess.opStmt, &srcIterationDomainContext))
|
||||
return false;
|
||||
|
||||
// Get iteration domain context for 'dstAccess'.
|
||||
IterationDomainContext dstIterationDomainContext;
|
||||
if (!getIterationDomainContext(dstAccess.opStmt, &dstIterationDomainContext))
|
||||
return false;
|
||||
|
||||
// Build dim and symbol position maps for each access from access operand
|
||||
// MLValue to position in merged contstraint system.
|
||||
DenseMap<const MLValue *, unsigned> srcDimPosMap;
|
||||
DenseMap<const MLValue *, unsigned> dstDimPosMap;
|
||||
DenseMap<const MLValue *, unsigned> symbolPosMap;
|
||||
buildDimAndSymbolPositionMaps(
|
||||
srcIterationDomainContext, dstIterationDomainContext, srcAccessMap,
|
||||
dstAccessMap, &srcDimPosMap, &dstDimPosMap, &symbolPosMap);
|
||||
|
||||
// TODO(andydavis) Add documentation.
|
||||
unsigned numIneq = srcIterationDomainContext.domain.getNumInequalities() +
|
||||
dstIterationDomainContext.domain.getNumInequalities();
|
||||
AffineMap srcMap = srcAccessMap.getAffineMap();
|
||||
assert(srcMap.getNumResults() == dstAccessMap.getAffineMap().getNumResults());
|
||||
unsigned numEq = srcMap.getNumResults();
|
||||
unsigned numDims = srcDimPosMap.size() + dstDimPosMap.size();
|
||||
unsigned numSymbols = symbolPosMap.size();
|
||||
unsigned numIds = numDims + numSymbols;
|
||||
unsigned numCols = numIds + 1;
|
||||
|
||||
// Create flat affine constraints reserving space for 'numEq' and 'numIneq'.
|
||||
// TODO(andydavis) better name.
|
||||
FlatAffineConstraints constraints(numIneq, numEq, numCols, numDims,
|
||||
numSymbols);
|
||||
// Create memref access constraint by equating src/dst access functions.
|
||||
// Note that this check is conservative, and will failure in the future
|
||||
// when local variables for mod/div exprs are supported.
|
||||
if (!addMemRefAccessConstraints(srcAccessMap, dstAccessMap, srcDimPosMap,
|
||||
dstDimPosMap, symbolPosMap, &constraints))
|
||||
return true;
|
||||
|
||||
// Add domain constraints for src access function.
|
||||
addDomainConstraints(srcIterationDomainContext, srcDimPosMap, symbolPosMap,
|
||||
&constraints);
|
||||
// Add equality constraints from 'dstConstraints'.
|
||||
addDomainConstraints(dstIterationDomainContext, dstDimPosMap, symbolPosMap,
|
||||
&constraints);
|
||||
bool isEmpty = constraints.isEmpty();
|
||||
// Return false if the solution space is empty.
|
||||
return !isEmpty;
|
||||
}
|
||||
|
|
|
@ -168,10 +168,9 @@ forwardSubstituteMutableAffineMap(const AffineMapCompositionUpdate &mapUpdate,
|
|||
map->setNumDims(mapUpdate.outputNumDims);
|
||||
map->setNumSymbols(mapUpdate.outputNumSymbols);
|
||||
}
|
||||
|
||||
MutableAffineMap::MutableAffineMap(AffineMap map)
|
||||
: numDims(map.getNumDims()), numSymbols(map.getNumSymbols()),
|
||||
// A map always has at leat 1 result by construction
|
||||
// A map always has at least 1 result by construction
|
||||
context(map.getResult(0).getContext()) {
|
||||
for (auto result : map.getResults())
|
||||
results.push_back(result);
|
||||
|
@ -179,6 +178,19 @@ MutableAffineMap::MutableAffineMap(AffineMap map)
|
|||
results.push_back(rangeSize);
|
||||
}
|
||||
|
||||
void MutableAffineMap::reset(AffineMap map) {
|
||||
results.clear();
|
||||
rangeSizes.clear();
|
||||
numDims = map.getNumDims();
|
||||
numSymbols = map.getNumSymbols();
|
||||
// A map always has at least 1 result by construction
|
||||
context = map.getResult(0).getContext();
|
||||
for (auto result : map.getResults())
|
||||
results.push_back(result);
|
||||
for (auto rangeSize : map.getRangeSizes())
|
||||
results.push_back(rangeSize);
|
||||
}
|
||||
|
||||
bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const {
|
||||
if (results[idx].isMultipleOf(factor))
|
||||
return true;
|
||||
|
@ -228,6 +240,15 @@ AffineValueMap::AffineValueMap(AffineMap map, ArrayRef<MLValue *> operands)
|
|||
}
|
||||
}
|
||||
|
||||
void AffineValueMap::reset(AffineMap map, ArrayRef<MLValue *> operands) {
|
||||
this->operands.clear();
|
||||
this->results.clear();
|
||||
this->map.reset(map);
|
||||
for (MLValue *operand : operands) {
|
||||
this->operands.push_back(operand);
|
||||
}
|
||||
}
|
||||
|
||||
void AffineValueMap::forwardSubstitute(const AffineApplyOp &inputOp) {
|
||||
SmallVector<bool, 4> inputResultsToSubstitute(inputOp.getNumResults());
|
||||
for (unsigned i = 0, e = inputOp.getNumResults(); i < e; i++)
|
||||
|
|
|
@ -53,21 +53,6 @@ FunctionPass *mlir::createMemRefBoundCheckPass() {
|
|||
return new MemRefBoundCheck();
|
||||
}
|
||||
|
||||
// Forward substitutes into 'valueMap' all AffineApplyOps reachable from the
|
||||
// operands of 'valueMap'.
|
||||
static void forwardSubstituteReachableOps(AffineValueMap *valueMap) {
|
||||
// Gather AffineApplyOps reachable from 'indices'.
|
||||
SmallVector<OperationStmt *, 4> affineApplyOps;
|
||||
getReachableAffineApplyOps(valueMap->getOperands(), affineApplyOps);
|
||||
// Compose AffineApplyOps in 'affineApplyOps'.
|
||||
for (auto *opStmt : affineApplyOps) {
|
||||
assert(opStmt->isa<AffineApplyOp>());
|
||||
auto affineApplyOp = opStmt->dyn_cast<AffineApplyOp>();
|
||||
// Forward substitute 'affineApplyOp' into 'valueMap'.
|
||||
valueMap->forwardSubstitute(*affineApplyOp);
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the memory region accessed by this memref.
|
||||
// TODO(bondhugula): extend this to store's and other memref dereferencing ops.
|
||||
bool getMemoryRegion(OpPointer<LoadOp> loadOp, FlatAffineConstraints *region) {
|
||||
|
|
|
@ -94,6 +94,15 @@ AffineMap AffineMap::getConstantMap(int64_t val, MLIRContext *context) {
|
|||
{getAffineConstantExpr(val, context)}, {});
|
||||
}
|
||||
|
||||
AffineMap AffineMap::getMultiDimIdentityMap(unsigned numDims,
|
||||
MLIRContext *context) {
|
||||
SmallVector<AffineExpr, 4> dimExprs;
|
||||
dimExprs.reserve(numDims);
|
||||
for (unsigned i = 0; i < numDims; ++i)
|
||||
dimExprs.push_back(mlir::getAffineDimExpr(i, context));
|
||||
return get(/*dimCount=*/numDims, /*symbolCount=*/0, dimExprs, {});
|
||||
}
|
||||
|
||||
MLIRContext *AffineMap::getContext() const { return getResult(0).getContext(); }
|
||||
|
||||
bool AffineMap::isBounded() const { return !map->rangeSizes.empty(); }
|
||||
|
|
|
@ -0,0 +1,244 @@
|
|||
//===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements loop fusion.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Analysis/AffineAnalysis.h"
|
||||
#include "mlir/Analysis/LoopAnalysis.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/StmtVisitor.h"
|
||||
#include "mlir/Pass.h"
|
||||
#include "mlir/StandardOps/StandardOps.h"
|
||||
#include "mlir/Transforms/LoopUtils.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
/// Loop fusion pass. This pass fuses adjacent loops in MLFunctions which
|
||||
/// access the same memref with no dependences.
|
||||
// See MatchTestPattern for details on candidate loop selection.
|
||||
// TODO(andydavis) Extend this pass to check for fusion preventing dependences,
|
||||
// and add support for more general loop fusion algorithms.
|
||||
struct LoopFusion : public FunctionPass {
  LoopFusion() {}

  // Runs the fusion pass on 'f': matches the test pattern, checks the
  // dependence between the two candidate accesses, and fuses if independent.
  PassResult runOnMLFunction(MLFunction *f) override;
};
|
||||
|
||||
// LoopCollector walks the statements in an MLFunction and builds a map from
// StmtBlocks to the list of loops within the StmtBlock, and a map from
// ForStmts to the list of loads and stores within its StmtBlock.
class LoopCollector : public StmtWalker<LoopCollector> {
public:
  // Maps each visited block to the ForStmts it directly contains.
  DenseMap<StmtBlock *, SmallVector<ForStmt *, 2>> loopMap;
  // Maps each ForStmt to the LoadOp/StoreOp statements directly nested in it.
  DenseMap<ForStmt *, SmallVector<OperationStmt *, 2>> loadsAndStoresMap;
  // Set when any IfStmt is seen; the test pattern matcher bails out on these.
  bool hasIfStmt = false;

  // Records 'forStmt' under its enclosing block.
  void visitForStmt(ForStmt *forStmt) {
    loopMap[forStmt->getBlock()].push_back(forStmt);
  }

  void visitIfStmt(IfStmt *ifStmt) { hasIfStmt = true; }

  // Records load/store operations whose immediate parent is a ForStmt.
  void visitOperationStmt(OperationStmt *opStmt) {
    if (auto *parentStmt = opStmt->getParentStmt()) {
      if (auto *parentForStmt = dyn_cast<ForStmt>(parentStmt)) {
        if (opStmt->isa<LoadOp>() || opStmt->isa<StoreOp>()) {
          loadsAndStoresMap[parentForStmt].push_back(opStmt);
        }
      }
    }
  }
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
// Creates a new LoopFusion pass instance.
FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; }
|
||||
|
||||
// TODO(andydavis) Remove the following test code when more general loop
// fusion is supported.
// Holds a pair of candidate loop nests, plus one memref access from each,
// which the test policy considers for fusion.
struct FusionCandidate {
  // Loop nest of ForStmts (outer-most first) with 'accessA' in the inner-most
  // loop.
  SmallVector<ForStmt *, 2> forStmtsA;
  // Load or store operation within loop nest 'forStmtsA'.
  MemRefAccess accessA;
  // Loop nest of ForStmts (outer-most first) with 'accessB' in the inner-most
  // loop.
  SmallVector<ForStmt *, 2> forStmtsB;
  // Load or store operation within loop nest 'forStmtsB'.
  MemRefAccess accessB;
};
|
||||
|
||||
// Populates 'access' with the memref, operation statement and indices of the
// load or store operation 'loadOrStoreOpStmt'.
// NOTE: indices are appended — assumes 'access->indices' starts empty (true
// for the freshly-constructed FusionCandidate used by the callers here).
static void getSingleMemRefAccess(OperationStmt *loadOrStoreOpStmt,
                                  MemRefAccess *access) {
  access->opStmt = loadOrStoreOpStmt;
  if (auto loadOp = loadOrStoreOpStmt->dyn_cast<LoadOp>()) {
    access->memref = cast<MLValue>(loadOp->getMemRef());
    access->indices.reserve(loadOp->getMemRefType().getRank());
    for (auto *index : loadOp->getIndices()) {
      access->indices.push_back(cast<MLValue>(index));
    }
  } else {
    // Assert on the dyn_cast result directly rather than a separate isa<>.
    auto storeOp = loadOrStoreOpStmt->dyn_cast<StoreOp>();
    assert(storeOp && "expected LoadOp or StoreOp");
    access->memref = cast<MLValue>(storeOp->getMemRef());
    access->indices.reserve(storeOp->getMemRefType().getRank());
    for (auto *index : storeOp->getIndices()) {
      access->indices.push_back(cast<MLValue>(index));
    }
  }
}
|
||||
|
||||
// Checks if 'forStmtA' and 'forStmtB' match specific test criterion:
|
||||
// constant loop bounds, no nested loops, single StoreOp in 'forStmtA' and
|
||||
// a single LoadOp in 'forStmtB'.
|
||||
// Returns true if the test pattern matches, false otherwise.
|
||||
static bool MatchTestPatternLoopPair(LoopCollector *lc,
|
||||
FusionCandidate *candidate,
|
||||
ForStmt *forStmtA, ForStmt *forStmtB) {
|
||||
if (forStmtA == nullptr || forStmtB == nullptr)
|
||||
return false;
|
||||
// Return if 'forStmtA' and 'forStmtB' do not have matching constant
|
||||
// bounds and step.
|
||||
if (!forStmtA->hasConstantBounds() || !forStmtB->hasConstantBounds() ||
|
||||
forStmtA->getConstantLowerBound() != forStmtB->getConstantLowerBound() ||
|
||||
forStmtA->getConstantUpperBound() != forStmtB->getConstantUpperBound() ||
|
||||
forStmtA->getStep() != forStmtB->getStep())
|
||||
return false;
|
||||
|
||||
// Return if 'forStmtA' or 'forStmtB' have nested loops.
|
||||
if (lc->loopMap.count(forStmtA) > 0 || lc->loopMap.count(forStmtB))
|
||||
return false;
|
||||
|
||||
// Return if 'forStmtA' or 'forStmtB' do not have exactly one load or store.
|
||||
if (lc->loadsAndStoresMap[forStmtA].size() != 1 ||
|
||||
lc->loadsAndStoresMap[forStmtB].size() != 1)
|
||||
return false;
|
||||
|
||||
// Get load/store access for forStmtA.
|
||||
getSingleMemRefAccess(lc->loadsAndStoresMap[forStmtA][0],
|
||||
&candidate->accessA);
|
||||
// Return if 'accessA' is not a store.
|
||||
if (!candidate->accessA.opStmt->isa<StoreOp>())
|
||||
return false;
|
||||
|
||||
// Get load/store access for forStmtB.
|
||||
getSingleMemRefAccess(lc->loadsAndStoresMap[forStmtB][0],
|
||||
&candidate->accessB);
|
||||
|
||||
// Return if accesses do not access the same memref.
|
||||
if (candidate->accessA.memref != candidate->accessB.memref)
|
||||
return false;
|
||||
|
||||
candidate->forStmtsA.push_back(forStmtA);
|
||||
candidate->forStmtsB.push_back(forStmtB);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns the child ForStmt of 'parent' if unique, returns false otherwise.
|
||||
ForStmt *getSingleForStmtChild(ForStmt *parent) {
|
||||
if (parent->getStatements().size() == 1 && isa<ForStmt>(parent->front()))
|
||||
return dyn_cast<ForStmt>(&parent->front());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Checks for a specific ForStmt/OpStatement test pattern in 'f', returns true
// on success and returns the fusion candidate in 'candidate'. Returns false
// otherwise.
// Currently supported test patterns:
// *) Adjacent loops with a StoreOp the only op in first loop, and a LoadOp the
//    only op in the second loop (both load/store accessing the same memref).
// *) As above, but with one level of perfect loop nesting.
//
// TODO(andydavis) Look into using ntv@ pattern matcher here.
static bool MatchTestPattern(MLFunction *f, FusionCandidate *candidate) {
  LoopCollector lc;
  lc.walk(f);
  // Return if an IfStmt was found or if less than two ForStmts were found.
  if (lc.hasIfStmt || lc.loopMap.count(f) == 0 || lc.loopMap[f].size() < 2)
    return false;
  auto *forStmtA = lc.loopMap[f][0];
  auto *forStmtB = lc.loopMap[f][1];
  if (!MatchTestPatternLoopPair(&lc, candidate, forStmtA, forStmtB)) {
    // Check for one level of loop nesting.
    // Record the outer loops first; a successful inner-pair match appends the
    // inner loops after them, keeping the nests outer-most first.
    candidate->forStmtsA.push_back(forStmtA);
    candidate->forStmtsB.push_back(forStmtB);
    return MatchTestPatternLoopPair(&lc, candidate,
                                    getSingleForStmtChild(forStmtA),
                                    getSingleForStmtChild(forStmtB));
  }
  return true;
}
|
||||
|
||||
// FuseLoops implements the code generation mechanics of loop fusion.
// Fuses the operation statements from the inner-most loop in 'c.forStmtsB',
// by cloning them into the inner-most loop in 'c.forStmtsA', then erasing
// old statements and loops.
static void fuseLoops(const FusionCandidate &c) {
  // Insert cloned statements at the end of the inner-most loop body of nest A.
  MLFuncBuilder builder(c.forStmtsA.back(),
                        StmtBlock::iterator(c.forStmtsA.back()->end()));
  DenseMap<const MLValue *, MLValue *> operandMap;
  assert(c.forStmtsA.size() == c.forStmtsB.size());
  for (unsigned i = 0, e = c.forStmtsA.size(); i < e; i++) {
    // Map the loop IV of 'forStmtsB[i]' to the loop IV of 'forStmtsA[i]', so
    // cloned statements use nest A's induction variables.
    operandMap[c.forStmtsB[i]] = c.forStmtsA[i];
  }
  // Clone the body of inner-most loop in 'forStmtsB', into the body of
  // inner-most loop in 'forStmtsA'.
  SmallVector<Statement *, 2> stmtsToErase;
  auto *innerForStmtB = c.forStmtsB.back();
  for (auto &stmt : *innerForStmtB) {
    builder.clone(stmt, operandMap);
    stmtsToErase.push_back(&stmt);
  }
  // Erase 'forStmtB' and its statement list.
  // Erase in reverse order: later statements may use values defined by
  // earlier ones.
  for (auto it = stmtsToErase.rbegin(); it != stmtsToErase.rend(); ++it)
    (*it)->erase();
  // Erase 'forStmtsB' loop nest, inner-most loop first since outer loops
  // contain the inner ones.
  for (int i = static_cast<int>(c.forStmtsB.size()) - 1; i >= 0; --i)
    c.forStmtsB[i]->erase();
}
|
||||
|
||||
PassResult LoopFusion::runOnMLFunction(MLFunction *f) {
  FusionCandidate candidate;
  // Nothing to fuse if 'f' does not match the narrow test pattern currently
  // supported (see MatchTestPattern).
  if (!MatchTestPattern(f, &candidate))
    return failure();

  // TODO(andydavis) Add checks for fusion-preventing dependences and ordering
  // constraints which would prevent fusion.
  // TODO(andydavis) This check is overly conservative for now. Support fusing
  // statements with compatible dependences (i.e. statements where the
  // dependence between the statements does not reverse direction when the
  // statements are fused into the same loop).
  if (!checkMemrefAccessDependence(candidate.accessA, candidate.accessB)) {
    // Current conservative test policy: no dependence exists between accesses
    // in different loop nests -> fuse loops.
    fuseLoops(candidate);
  }

  return success();
}
|
|
@ -0,0 +1,191 @@
|
|||
// RUN: mlir-opt %s -loop-fusion | FileCheck %s
|
||||
|
||||
// CHECK: [[MAP0:#map[0-9]+]] = (d0) -> (d0 * 2 + 2)
|
||||
// CHECK: [[MAP1:#map[0-9]+]] = (d0) -> (d0 * 3 + 1)
|
||||
// CHECK: [[MAP2:#map[0-9]+]] = (d0) -> (d0 * 2)
|
||||
// CHECK: [[MAP3:#map[0-9]+]] = (d0) -> (d0 * 2 + 1)
|
||||
// CHECK: [[MAP4:#map[0-9]+]] = (d0, d1)[s0, s1] -> (d0 * 2 - d1 - s0 * 7 + 3, d0 * 9 + d1 * 3 + s1 * 13 - 10)
|
||||
// CHECK: [[MAP5:#map[0-9]+]] = (d0, d1)[s0, s1] -> (d0 * 2 - 1, d1 * 3 + s0 * 2 + s1 * 3)
|
||||
// CHECK: [[MAP6:#map[0-9]+]] = (d0, d1)[s0, s1] -> (d0 * 2 - 1, d1 * 3 + s0 + s1 * 3)
|
||||
|
||||
// The dependence check for this test builds the following set of constraints,
|
||||
// where the equality constraint equates the two accesses to the memref (from
|
||||
// different loops), and the inequality constraints represent the upper and
|
||||
// lower bounds for each loop. After elimination, this linear system can be
|
||||
// shown to be non-empty (i.e. x0 = x1 = 1 is a solution). As such, the
|
||||
// dependence check between accesses in the two loops will return true, and
|
||||
// the loops (according to the current test loop fusion algorithm) should not be
|
||||
// fused.
|
||||
//
|
||||
// x0 x1 x2
|
||||
// 2 -3 1 = 0
|
||||
// 1 0 0 >= 0
|
||||
// -1 0 100 >= 0
|
||||
// 0 1 0 >= 0
|
||||
// 0 -1 100 >= 0
|
||||
//
|
||||
// CHECK-LABEL: mlfunc @loop_fusion_1d_should_not_fuse_loops() {
|
||||
mlfunc @loop_fusion_1d_should_not_fuse_loops() {
  // Expected: a dependence exists (e.g. 2*i0 + 2 == 3*i1 + 1 at i0 = i1 = 1),
  // so under the current test policy both loops must remain in place.
  %m = alloc() : memref<100xf32, (d0) -> (d0)>
  // Check that the first loop remains unfused.
  // CHECK: for %i0 = 0 to 100 {
  // CHECK-NEXT: [[I0:%[0-9]+]] = affine_apply [[MAP0]](%i0)
  // CHECK: store {{.*}}, %{{[0-9]+}}{{\[}}[[I0]]{{\]}}
  // CHECK-NEXT: }
  for %i0 = 0 to 100 {
    %a0 = affine_apply (d0) -> (d0 * 2 + 2) (%i0)
    %c1 = constant 1.0 : f32
    store %c1, %m[%a0] : memref<100xf32, (d0) -> (d0)>
  }
  // Check that the second loop remains unfused.
  // CHECK: for %i1 = 0 to 100 {
  // CHECK-NEXT: [[I1:%[0-9]+]] = affine_apply [[MAP1]](%i1)
  // CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I1]]{{\]}}
  // CHECK-NEXT: }
  for %i1 = 0 to 100 {
    %a1 = affine_apply (d0) -> (d0 * 3 + 1) (%i1)
    %v0 = load %m[%a1] : memref<100xf32, (d0) -> (d0)>
  }
  return
}
|
||||
|
||||
// The dependence check for this test builds the following set of constraints:
|
||||
//
|
||||
// x0 x1 x2
|
||||
// 2 -2 -1 = 0
|
||||
// 1 0 0 >= 0
|
||||
// -1 0 100 >= 0
|
||||
// 0 1 0 >= 0
|
||||
// 0 -1 100 >= 0
|
||||
//
|
||||
// After elimination, this linear system can be shown to have no solutions, and
|
||||
// so no dependence exists and the loops should be fused in this test (according
|
||||
// to the current trivial test loop fusion policy).
|
||||
//
|
||||
//
|
||||
// CHECK-LABEL: mlfunc @loop_fusion_1d_should_fuse_loops() {
|
||||
mlfunc @loop_fusion_1d_should_fuse_loops() {
  // Expected: the store writes even elements (2*i0) and the load reads odd
  // elements (2*i1 + 1), so no dependence exists and the loops are fused.
  %m = alloc() : memref<100xf32, (d0) -> (d0)>
  // Should fuse statements from the second loop into the first loop.
  // CHECK: for %i0 = 0 to 100 {
  // CHECK-NEXT: [[I0:%[0-9]+]] = affine_apply [[MAP2]](%i0)
  // CHECK: store {{.*}}, %{{[0-9]+}}{{\[}}[[I0]]{{\]}}
  // CHECK-NEXT: [[I1:%[0-9]+]] = affine_apply [[MAP3]](%i0)
  // CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I1]]{{\]}}
  // CHECK-NEXT: }
  // CHECK-NEXT: return
  for %i0 = 0 to 100 {
    %a0 = affine_apply (d0) -> (d0 * 2) (%i0)
    %c1 = constant 1.0 : f32
    store %c1, %m[%a0] : memref<100xf32, (d0) -> (d0)>
  }

  for %i1 = 0 to 100 {
    %a1 = affine_apply (d0) -> (d0 * 2 + 1) (%i1)

    %v0 = load %m[%a1] : memref<100xf32, (d0) -> (d0)>
  }
  return
}
|
||||
|
||||
// The dependence check for this test builds the following set of
|
||||
// equality constraints (one for each memref dimension). Note: inequality
|
||||
// constraints for loop bounds not shown.
|
||||
//
|
||||
// i0 i1 i2 i3 s0 s1 s2 c
|
||||
// 2 -1 -2 0 -7 0 0 4 = 0
|
||||
// 9 3 0 -3 0 11 -3 -10 = 0
|
||||
//
|
||||
// CHECK-LABEL: mlfunc @loop_fusion_2d_should_not_fuse_loops() {
|
||||
mlfunc @loop_fusion_2d_should_not_fuse_loops() {
  // Expected: the 2-d dependence system has a solution, so both loop nests
  // must remain in place under the current test policy.
  %m = alloc() : memref<10x10xf32>

  %s0 = constant 7 : index
  %s1 = constant 11 : index
  %s2 = constant 13 : index
  // Check that the first loop remains unfused.
  // CHECK: for %i0 = 0 to 100 {
  // CHECK-NEXT: for %i1 = 0 to 50 {
  // CHECK-NEXT: [[I0:%[0-9]+]] = affine_apply [[MAP4]](%i0, %i1)[%c7, %c11]
  // CHECK: store {{.*}}, %{{[0-9]+}}{{\[}}[[I0]]#0, [[I0]]#1{{\]}}
  // CHECK-NEXT: }
  // CHECK-NEXT: }
  for %i0 = 0 to 100 {
    for %i1 = 0 to 50 {
      %a0 = affine_apply
        (d0, d1)[s0, s1] ->
          (d0 * 2 -d1 + -7 * s0 + 3 , d0 * 9 + d1 * 3 + 13 * s1 - 10)
            (%i0, %i1)[%s0, %s1]
      %c1 = constant 1.0 : f32
      store %c1, %m[%a0#0, %a0#1] : memref<10x10xf32>
    }
  }
  // Check that the second loop remains unfused.
  // CHECK: for %i2 = 0 to 100 {
  // CHECK-NEXT: for %i3 = 0 to 50 {
  // CHECK-NEXT: [[I1:%[0-9]+]] = affine_apply [[MAP5]](%i2, %i3)[%c11, %c13]
  // CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I1]]#0, [[I1]]#1{{\]}}
  // CHECK-NEXT: }
  // CHECK-NEXT: }
  for %i2 = 0 to 100 {
    for %i3 = 0 to 50 {
      %a1 = affine_apply
        (d0, d1)[s0, s1] ->
          (d0 * 2 - 1, d1 * 3 + s0 * 2 + s1 * 3) (%i2, %i3)[%s1, %s2]
      %v0 = load %m[%a1#0, %a1#1] : memref<10x10xf32>
    }
  }

  return
}
|
||||
|
||||
// The dependence check for this test builds the following set of
|
||||
// equality constraints (one for each memref dimension). Note: inequality
|
||||
// constraints for loop bounds not shown.
|
||||
//
|
||||
// i0 i1 i2 i3 s0 s1 s2 c
|
||||
// 2 -1 -2 0 -7 0 0 4 = 0
|
||||
// 9 3 0 -3 0 12 -3 -10 = 0
|
||||
//
|
||||
// The second equality will fail the GCD test and so the system has no solution,
|
||||
// so the loops should be fused under the current test policy.
|
||||
//
|
||||
// CHECK-LABEL: mlfunc @loop_fusion_2d_should_fuse_loops() {
|
||||
mlfunc @loop_fusion_2d_should_fuse_loops() {
  // Expected: the second equality constraint fails the GCD test, so no
  // dependence exists and the loop nests are fused.
  %m = alloc() : memref<10x10xf32>

  %s0 = constant 7 : index
  %s1 = constant 11 : index
  %s2 = constant 13 : index
  // Should fuse statements from the second loop into the first loop.
  // CHECK: for %i0 = 0 to 100 {
  // CHECK-NEXT: for %i1 = 0 to 50 {
  // CHECK-NEXT: [[I0:%[0-9]+]] = affine_apply [[MAP4]](%i0, %i1)[%c7, %c11]
  // CHECK: store {{.*}}, %{{[0-9]+}}{{\[}}[[I0]]#0, [[I0]]#1{{\]}}
  // CHECK-NEXT: [[I1:%[0-9]+]] = affine_apply [[MAP6]](%i0, %i1)[%c11, %c13]
  // CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I1]]#0, [[I1]]#1{{\]}}
  // CHECK-NEXT: }
  // CHECK-NEXT: }
  // CHECK-NEXT: return
  for %i0 = 0 to 100 {
    for %i1 = 0 to 50 {
      %a0 = affine_apply
        (d0, d1)[s0, s1] ->
          (d0 * 2 -d1 + -7 * s0 + 3 , d0 * 9 + d1 * 3 + 13 * s1 - 10)
            (%i0, %i1)[%s0, %s1]
      %c1 = constant 1.0 : f32
      store %c1, %m[%a0#0, %a0#1] : memref<10x10xf32>
    }
  }

  for %i2 = 0 to 100 {
    for %i3 = 0 to 50 {
      %a1 = affine_apply
        (d0, d1)[s0, s1] ->
          (d0 * 2 - 1, d1 * 3 + s0 + s1 * 3) (%i2, %i3)[%s1, %s2]
      %v0 = load %m[%a1#0, %a1#1] : memref<10x10xf32>
    }
  }

  return
}
|
|
@ -72,6 +72,7 @@ enum Passes {
|
|||
ConstantFold,
|
||||
ConvertToCFG,
|
||||
MemRefBoundCheck,
|
||||
LoopFusion,
|
||||
LoopUnroll,
|
||||
LoopUnrollAndJam,
|
||||
PipelineDataTransfer,
|
||||
|
@ -94,6 +95,7 @@ static cl::list<Passes> passList(
|
|||
"Convert all ML functions in the module to CFG ones"),
|
||||
clEnumValN(MemRefBoundCheck, "memref-bound-check",
|
||||
"Convert all ML functions in the module to CFG ones"),
|
||||
clEnumValN(LoopFusion, "loop-fusion", "Fuse loop nests"),
|
||||
clEnumValN(LoopUnroll, "loop-unroll", "Unroll loops"),
|
||||
clEnumValN(LoopUnrollAndJam, "loop-unroll-jam", "Unroll and jam loops"),
|
||||
clEnumValN(PipelineDataTransfer, "pipeline-data-transfer",
|
||||
|
@ -198,6 +200,9 @@ static OptResult performActions(SourceMgr &sourceMgr, MLIRContext *context) {
|
|||
case MemRefBoundCheck:
|
||||
pass = createMemRefBoundCheckPass();
|
||||
break;
|
||||
case LoopFusion:
|
||||
pass = createLoopFusionPass();
|
||||
break;
|
||||
case LoopUnroll:
|
||||
pass = createLoopUnrollPass();
|
||||
break;
|
||||
|
|
Loading…
Reference in New Issue