Adds a dependence check to test whether two accesses to the same memref access the same element.

- Builds access functions and iteration domains for each access.
- Builds a dependence polyhedron constraint system, which has equality constraints for the equated access functions and inequality constraints for the iteration domain loop bounds.
- Runs elimination on the dependence polyhedron to test if no dependence exists between the accesses.
- Adds a trivial LoopFusion transformation pass with a simple test policy to test dependence between accesses to the same memref in adjacent loops.
- The LoopFusion pass will be extended in subsequent CLs; a sketch of how the new pieces fit together follows below.

PiperOrigin-RevId: 219630898
MLIR Team 2018-11-01 07:26:00 -07:00 committed by jpienaar
parent 21638dcda9
commit f28e4df666
11 changed files with 985 additions and 17 deletions
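
For orientation, here is a minimal sketch of how the pieces added in this CL fit together (illustrative only: the surrounding driver code is hypothetical, while MemRefAccess, getSingleMemRefAccess and checkMemrefAccessDependence are the names introduced below):

// Build MemRefAccess structs for a store and a load in adjacent loop nests
// ('storeOpStmt' and 'loadOpStmt' are assumed OperationStmt* values).
MemRefAccess srcAccess, dstAccess;
getSingleMemRefAccess(storeOpStmt, &srcAccess);
getSingleMemRefAccess(loadOpStmt, &dstAccess);
// Builds the dependence polyhedron (equated access functions plus loop bound
// inequalities) and runs elimination on it.
if (!checkMemrefAccessDependence(srcAccess, dstAccess)) {
  // Conclusively no dependence: the accesses never touch the same element,
  // so the trivial test policy allows fusing the two loops.
}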


@@ -31,6 +31,9 @@
namespace mlir {
class AffineExpr;
class AffineValueMap;
class ForStmt;
class MLIRContext;
class MLValue;
class OperationStmt;
@@ -48,6 +51,10 @@ void getReachableAffineApplyOps(
llvm::ArrayRef<MLValue *> operands,
llvm::SmallVectorImpl<OperationStmt *> &affineApplyOps);
/// Forward substitutes into 'valueMap' all AffineApplyOps reachable from the
/// operands of 'valueMap'.
void forwardSubstituteReachableOps(AffineValueMap *valueMap);
/// Flattens 'expr' into 'flattenedExpr'. Returns true on success or false
/// if 'expr' was unable to be flattened (i.e. because it was not pure affine,
/// or because it contained mod's and div's that could not be eliminated
@@ -56,6 +63,22 @@ bool getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
unsigned numSymbols,
llvm::SmallVectorImpl<int64_t> *flattenedExpr);
/// Checks whether two accesses to the same memref access the same element.
/// Each access is specified using the MemRefAccess structure, which contains
/// the operation statement, indices and memref associated with the access.
/// Returns 'false' if it can be determined conclusively that the accesses do
/// not access the same memref element. Returns 'true' otherwise.
struct MemRefAccess {
const MLValue *memref;
const OperationStmt *opStmt;
llvm::SmallVector<MLValue *, 4> indices;
// Populates 'accessMap' with composition of AffineApplyOps reachable from
// 'indices'.
void getAccessMap(AffineValueMap *accessMap) const;
};
bool checkMemrefAccessDependence(const MemRefAccess &srcAccess,
const MemRefAccess &dstAccess);
} // end namespace mlir
#endif // MLIR_ANALYSIS_AFFINE_ANALYSIS_H


@@ -42,6 +42,7 @@ class HyperRectangularSet;
/// A mutable affine map. Its affine expressions are however unique.
struct MutableAffineMap {
public:
MutableAffineMap() {}
MutableAffineMap(AffineMap map);
AffineExpr getResult(unsigned idx) const { return results[idx]; }
@@ -56,6 +57,9 @@ public:
/// Returns true if the idx'th result expression is a multiple of factor.
bool isMultipleOf(unsigned idx, int64_t factor) const;
/// Resets this MutableAffineMap with 'map'.
void reset(AffineMap map);
/// Simplify the (result) expressions in this map using analysis (used by
/// the -simplify-affine-expr pass).
void simplify();
@@ -120,6 +124,9 @@ private:
// TODO(bondhugula): Some of these classes could go into separate files.
class AffineValueMap {
public:
// Creates an empty AffineValueMap (users should call 'reset' to reset map
// and operands).
AffineValueMap() {}
AffineValueMap(const AffineApplyOp &op);
AffineValueMap(const AffineBound &bound);
AffineValueMap(AffineMap map);
@@ -145,6 +152,8 @@ public:
// can be used to amortize the cost of simplification over multiple fwd
// substitutions).
// Resets this AffineValueMap with 'map' and 'operands'.
void reset(AffineMap map, ArrayRef<MLValue *> operands);
/// Return true if the idx^th result can be proved to be a multiple of
/// 'factor', false otherwise.
inline bool isMultipleOf(unsigned idx, int64_t factor) const;


@@ -56,6 +56,10 @@ public:
/// Returns a single constant result affine map.
static AffineMap getConstantMap(int64_t val, MLIRContext *context);
/// Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap getMultiDimIdentityMap(unsigned numDims,
MLIRContext *context);
MLIRContext *getContext() const;
explicit operator bool() { return map; }


@@ -52,6 +52,9 @@ FunctionPass *createLoopUnrollAndJamPass(int unrollJamFactor = -1);
/// Creates a simplification pass for affine structures.
FunctionPass *createSimplifyAffineStructuresPass();
/// Creates a loop fusion pass which fuses loops in MLFunctions.
FunctionPass *createLoopFusionPass();
/// Creates a pass to pipeline explicit movement of data across levels of the
/// memory hierarchy.
FunctionPass *createPipelineDataTransferPass();


@@ -25,6 +25,8 @@
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Statements.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
@@ -370,3 +372,475 @@ void mlir::getReachableAffineApplyOps(
}
}
}
// Forward substitutes into 'valueMap' all AffineApplyOps reachable from the
// operands of 'valueMap'.
void mlir::forwardSubstituteReachableOps(AffineValueMap *valueMap) {
// Gather AffineApplyOps reachable from 'indices'.
SmallVector<OperationStmt *, 4> affineApplyOps;
getReachableAffineApplyOps(valueMap->getOperands(), affineApplyOps);
// Compose AffineApplyOps in 'affineApplyOps'.
for (auto *opStmt : affineApplyOps) {
assert(opStmt->isa<AffineApplyOp>());
auto affineApplyOp = opStmt->dyn_cast<AffineApplyOp>();
// Forward substitute 'affineApplyOp' into 'valueMap'.
valueMap->forwardSubstitute(*affineApplyOp);
}
}
// Adds loop upper and lower bound inequalities to 'domain' for each ForStmt
// value in 'values'. Requires that the first 'numDims' MLValues in 'values'
// are ForStmts. Returns true if lower/upper bound inequalities were
// successfully added, returns false otherwise.
// TODO(andydavis) Get operands for loop bounds so we can add domain
// constraints for non-constant loop bounds.
static bool addLoopBoundInequalities(unsigned numDims,
ArrayRef<const MLValue *> values,
FlatAffineConstraints *domain) {
assert(values.size() >= numDims);
unsigned numIds = values.size();
// Add inequalities for loop bounds.
SmallVector<int64_t, 4> ineq;
ineq.resize(numIds + 1);
for (unsigned i = 0; i < numDims; ++i) {
const ForStmt *forStmt = dyn_cast<ForStmt>(values[i]);
if (!forStmt || !forStmt->hasConstantBounds())
return false;
// Zero fill
std::fill(ineq.begin(), ineq.end(), 0);
// TODO(andydavis, bondhugula) Add methods for addUpper/LowerBound.
// Add inequality for lower bound.
ineq[i] = 1;
ineq[numIds] = -forStmt->getConstantLowerBound();
domain->addInequality(ineq);
// Add inequality for upper bound.
ineq[i] = -1;
ineq[numIds] = forStmt->getConstantUpperBound();
domain->addInequality(ineq);
}
return true;
}
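// Illustration (editorial sketch, not part of this CL): for a single loop
// 'for %i = 0 to 100' at dim position 0 of a two-identifier system, the rows
// added above are, with columns [i, j, const] and each row read as '>= 0':
//   lower bound:   i >= 0         ->  {  1, 0,   0 }
//   upper bound:  -i + 100 >= 0   ->  { -1, 0, 100 }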
// IterationDomainContext encapsulates the state required to represent
// the iteration domain of an OperationStmt.
struct IterationDomainContext {
// Set of inequality constraint pairs, where each pair represents the
// upper/lower bounds of a ForStmt in the iteration domain.
FlatAffineConstraints domain;
// The number of dimension identifiers in 'values'.
unsigned numDims;
// The list of MLValues in this iteration domain, with MLValues in
// [0, numDims) representing dimension identifiers, and MLValues in
// [numDims, values.size()) representing symbol identifiers.
SmallVector<const MLValue *, 4> values;
IterationDomainContext() : numDims(0) {}
unsigned getNumDims() { return numDims; }
unsigned getNumSymbols() { return values.size() - numDims; }
};
// Computes the iteration domain for 'opStmt' and populates 'ctx', which
// encapsulates the following state for each ForStmt in 'opStmt's iteration
// domain:
// *) adds inequality constraints representing the ForStmt upper/lower bounds.
// *) adds MLValues and symbols for the ForStmt and its operands to a list.
// TODO(andydavis) Add support for IfStmts in iteration domain.
// TODO(andydavis) Handle non-constant loop bounds by composing affine maps
// for each ForStmt loop bound and adding de-duped ids/symbols to iteration
// domain context.
// TODO(andydavis) Capture the context of the symbols. For example, check
// if a symbol is the result of a constant operation, and set the symbol to
// that value in FlatAffineConstraints (using setIdToConstant).
static bool getIterationDomainContext(const OperationStmt *opStmt,
IterationDomainContext *ctx) {
// Walk up tree storing parent statements in 'loops'.
// TODO(andydavis) Extend this to gather enclosing IfStmts and consider
// factoring it out into a utility function.
SmallVector<const ForStmt *, 4> loops;
const auto *currStmt = opStmt->getParentStmt();
while (currStmt != nullptr) {
if (isa<IfStmt>(currStmt))
return false;
assert(isa<ForStmt>(currStmt));
auto *forStmt = dyn_cast<ForStmt>(currStmt);
loops.push_back(forStmt);
currStmt = currStmt->getParentStmt();
}
// Iterate through 'loops' from outer-most loop to inner-most loop.
// Populate 'values'.
ctx->values.reserve(loops.size());
for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) {
auto *forStmt = loops[i];
// TODO(andydavis) Compose affine maps into lower/upper bounds of 'forStmt'
// and add de-duped symbols to ctx.symbols.
if (!forStmt->hasConstantBounds())
return false;
ctx->values.push_back(forStmt);
ctx->numDims++;
}
// Resize the flat affine constraint system based on the number of dims and
// symbols found.
unsigned numDims = ctx->getNumDims();
unsigned numSymbols = ctx->getNumSymbols();
ctx->domain.reset(/*newNumReservedInequalities=*/2 * numDims,
/*newNumReservedEqualities=*/0,
/*newNumReservedCols=*/numDims + numSymbols + 1, numDims,
numSymbols);
return addLoopBoundInequalities(numDims, ctx->values, &ctx->domain);
}
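// Example (editorial sketch mirroring the loop nests in the test file below):
// for an OperationStmt nested as 'for %i0 = 0 to 100 { for %i1 = 0 to 50 {..} }',
// this yields ctx->values = [%i0, %i1], numDims = 2, numSymbols = 0, and a
// domain holding the four inequalities:
//   i0 >= 0, -i0 + 100 >= 0, i1 >= 0, -i1 + 50 >= 0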
// Builds a map from MLValue to identifier position in a new merged identifier
// list, which is the result of merging dim/symbol lists from src/dst
// iteration domains. The format of the new merged list is as follows:
//
// [src-dim-identifiers, dst-dim-identifiers, symbol-identifiers]
//
// This method populates 'srcDimPosMap' and 'dstDimPosMap' with mappings from
// operand MLValues in 'srcAccessMap'/'dstAccessMap' to the position of these
// values in the merged list.
// In addition, this method populates 'symbolPosMap' with mappings from
// operand MLValues in both 'srcIterationDomainContext' and
// 'dstIterationDomainContext' to position of these values in the merged list.
static void buildDimAndSymbolPositionMaps(
const IterationDomainContext &srcIterationDomainContext,
const IterationDomainContext &dstIterationDomainContext,
const AffineValueMap &srcAccessMap, const AffineValueMap &dstAccessMap,
DenseMap<const MLValue *, unsigned> *srcDimPosMap,
DenseMap<const MLValue *, unsigned> *dstDimPosMap,
DenseMap<const MLValue *, unsigned> *symbolPosMap) {
unsigned pos = 0;
auto updatePosMap = [&](DenseMap<const MLValue *, unsigned> *posMap,
ArrayRef<const MLValue *> values, unsigned start,
unsigned limit) {
for (unsigned i = start; i < limit; ++i) {
auto *value = values[i];
auto it = posMap->find(value);
if (it == posMap->end()) {
(*posMap)[value] = pos++;
}
}
};
AffineMap srcMap = srcAccessMap.getAffineMap();
AffineMap dstMap = dstAccessMap.getAffineMap();
// Update position map with src dimension identifiers from iteration domain
// and access function.
updatePosMap(srcDimPosMap, srcIterationDomainContext.values, 0,
srcIterationDomainContext.numDims);
// Update position map with dst dimension identifiers from iteration domain
// and access function.
updatePosMap(dstDimPosMap, dstIterationDomainContext.values, 0,
dstIterationDomainContext.numDims);
// Update position map with src symbol identifiers from iteration domain
// and access function.
updatePosMap(symbolPosMap, srcIterationDomainContext.values,
srcIterationDomainContext.numDims,
srcIterationDomainContext.values.size());
updatePosMap(symbolPosMap, srcAccessMap.getOperands(), srcMap.getNumDims(),
srcMap.getNumDims() + srcMap.getNumSymbols());
// Update position map with dst symbol identifiers from iteration domain
// and access function.
updatePosMap(symbolPosMap, dstIterationDomainContext.values,
dstIterationDomainContext.numDims,
dstIterationDomainContext.values.size());
updatePosMap(symbolPosMap, dstAccessMap.getOperands(), dstMap.getNumDims(),
dstMap.getNumDims() + dstMap.getNumSymbols());
}
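// Example (editorial sketch, using the 2-d example documented at
// checkMemrefAccessDependence below): with src IVs (%i0, %i1), dst IVs
// (%i2, %i3), and symbols %M, %N, %K, the maps come out as:
//   srcDimPosMap: %i0 -> 0, %i1 -> 1
//   dstDimPosMap: %i2 -> 2, %i3 -> 3
//   symbolPosMap: %M -> 4, %N -> 5, %K -> 6
// i.e. the columns [src-dims, dst-dims, symbols, const] of the merged system.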
static unsigned getPos(const DenseMap<const MLValue *, unsigned> &posMap,
const MLValue *value) {
auto it = posMap.find(value);
assert(it != posMap.end());
return it->second;
}
// Adds iteration domain constraints from 'ctx.domain' into 'outputFac'.
// Uses 'dimPosMap' to map from dim operand value in 'ctx.values', to dim
// position in 'outputFac'.
// Uses 'symbolPosMap' to map from symbol operand value in 'ctx.values', to
// symbol position in 'outputFac'.
static void
addDomainConstraints(const IterationDomainContext &ctx,
const DenseMap<const MLValue *, unsigned> &dimPosMap,
const DenseMap<const MLValue *, unsigned> &symbolPosMap,
FlatAffineConstraints *outputFac) {
unsigned inputNumIneq = ctx.domain.getNumInequalities();
unsigned inputNumDims = ctx.domain.getNumDimIds();
unsigned inputNumSymbols = ctx.domain.getNumSymbolIds();
unsigned inputNumIds = inputNumDims + inputNumSymbols;
unsigned outputNumDims = outputFac->getNumDimIds();
unsigned outputNumSymbols = outputFac->getNumSymbolIds();
unsigned outputNumIds = outputNumDims + outputNumSymbols;
SmallVector<int64_t, 4> ineq;
ineq.resize(outputNumIds + 1);
for (unsigned i = 0; i < inputNumIneq; ++i) {
// Zero fill.
std::fill(ineq.begin(), ineq.end(), 0);
// Add dim identifiers.
for (unsigned j = 0; j < inputNumDims; ++j)
ineq[getPos(dimPosMap, ctx.values[j])] = ctx.domain.atIneq(i, j);
// Add symbol identifiers.
for (unsigned j = inputNumDims; j < inputNumIds; ++j) {
ineq[getPos(symbolPosMap, ctx.values[j])] = ctx.domain.atIneq(i, j);
}
// Add constant term.
ineq[outputNumIds] = ctx.domain.atIneq(i, inputNumIds);
// Add inequality constraint.
outputFac->addInequality(ineq);
}
}
// Adds equality constraints that equate src and dst access functions
// represented by 'srcAccessMap' and 'dstAccessMap' for each result.
// Requires that 'srcAccessMap' and 'dstAccessMap' have the same number of
// results. For example, given the following two access functions to a 2D memref:
//
// Source access function:
// (a0 * d0 + a1 * s0 + a2, b0 * d0 + b1 * s0 + b2)
//
// Destination access function:
// (c0 * d0 + c1 * s0 + c2, f0 * d0 + f1 * s0 + f2)
//
// This method constructs the following equality constraints in 'outputFac',
// by equating the access functions for each result (i.e. each memref dim).
// (notice that 'd0' for the destination access function is mapped into 'd1'
// in the equality constraint):
//
// d0 d1 s0 c
// -- -- -- --
// a0 -c0 (a1 - c1) (a2 - c2) = 0
// b0 -f0 (b1 - f1) (b2 - f2) = 0
//
static bool addMemRefAccessConstraints(
const AffineValueMap &srcAccessMap, const AffineValueMap &dstAccessMap,
const DenseMap<const MLValue *, unsigned> &srcDimPosMap,
const DenseMap<const MLValue *, unsigned> &dstDimPosMap,
const DenseMap<const MLValue *, unsigned> &symbolPosMap,
FlatAffineConstraints *outputFac) {
AffineMap srcMap = srcAccessMap.getAffineMap();
AffineMap dstMap = dstAccessMap.getAffineMap();
assert(srcMap.getNumResults() == dstMap.getNumResults());
unsigned numResults = srcMap.getNumResults();
unsigned srcNumDims = srcMap.getNumDims();
unsigned srcNumSymbols = srcMap.getNumSymbols();
unsigned srcNumIds = srcNumDims + srcNumSymbols;
ArrayRef<MLValue *> srcOperands = srcAccessMap.getOperands();
unsigned dstNumDims = dstMap.getNumDims();
unsigned dstNumSymbols = dstMap.getNumSymbols();
unsigned dstNumIds = dstNumDims + dstNumSymbols;
ArrayRef<MLValue *> dstOperands = dstAccessMap.getOperands();
unsigned outputNumDims = outputFac->getNumDimIds();
unsigned outputNumSymbols = outputFac->getNumSymbolIds();
unsigned outputNumIds = outputNumDims + outputNumSymbols;
SmallVector<int64_t, 4> eq;
eq.resize(outputNumIds + 1);
SmallVector<int64_t, 4> flattenedExpr;
for (unsigned i = 0; i < numResults; ++i) {
// Zero fill.
std::fill(eq.begin(), eq.end(), 0);
// Get flattened AffineExpr for result 'i' from src access function.
auto srcExpr = srcMap.getResult(i);
flattenedExpr.clear();
if (!getFlattenedAffineExpr(srcExpr, srcNumDims, srcNumSymbols,
&flattenedExpr))
return false;
// Add dim identifier coefficients from src access function.
for (unsigned j = 0, e = srcNumDims; j < e; ++j)
eq[getPos(srcDimPosMap, srcOperands[j])] = flattenedExpr[j];
// Add symbol identifiers from src access function.
for (unsigned j = srcNumDims; j < srcNumIds; ++j)
eq[getPos(symbolPosMap, srcOperands[j])] = flattenedExpr[j];
// Add constant term.
eq[outputNumIds] = flattenedExpr[srcNumIds];
// Get flattened AffineExpr for result 'i' from dst access function.
auto dstExpr = dstMap.getResult(i);
flattenedExpr.clear();
if (!getFlattenedAffineExpr(dstExpr, dstNumDims, dstNumSymbols,
&flattenedExpr))
return false;
// Add dim identifier coefficients from dst access function.
for (unsigned j = 0, e = dstNumDims; j < e; ++j)
eq[getPos(dstDimPosMap, dstOperands[j])] = -flattenedExpr[j];
// Add symbol identifiers from dst access function.
for (unsigned j = dstNumDims; j < dstNumIds; ++j)
eq[getPos(symbolPosMap, dstOperands[j])] -= flattenedExpr[j];
// Add constant term.
eq[outputNumIds] -= flattenedExpr[dstNumIds];
// Add equality constraint.
outputFac->addEquality(eq);
}
return true;
}
// Populates 'accessMap' with composition of AffineApplyOps reachable from
// indices of MemRefAccess.
void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const {
auto memrefType = memref->getType().cast<MemRefType>();
// Create an identity map with the same number of dimensions as the rank of
// 'memrefType'.
auto map = AffineMap::getMultiDimIdentityMap(memrefType.getRank(),
memref->getType().getContext());
// Reset 'accessMap' with 'map' and the access 'indices'.
accessMap->reset(map, indices);
// Compose 'accessMap' with reachable AffineApplyOps.
forwardSubstituteReachableOps(accessMap);
}
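// Example (editorial sketch, using the source access from the example below):
// for 'store %v0, %m[%a0#0, %a0#1]' the indices seed a 2-d identity map, and
// forward substituting the reachable affine_apply leaves:
//   map:      (d0, d1)[s0, s1] -> (d0 * 2 - d1 * 4 + s1, d1 * 3 - s0)
//   operands: (%i0, %i1)[%M, %N]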
// Builds a flat affine constraint system to check if there exists a dependence
// between memref accesses 'srcAccess' and 'dstAccess'.
// Returns 'false' if the accesses can be definitively shown not to access the
// same element. Returns 'true' otherwise.
//
// The memref access dependence check consists of the following steps:
// *) Compute access functions for each access. Access functions are computed
// using AffineValueMaps initialized with the indices from an access, then
// composed with AffineApplyOps reachable from operands of that access,
// until operands of the AffineValueMap are loop IVs or symbols.
// *) Build iteration domain constraints for each access. Iteration domain
// constraints are pairs of inequality constraints representing the
// upper/lower loop bounds for each ForStmt in the loop nest associated
// with each access.
// *) Build dimension and symbol position maps for each access, which map
// MLValues from access functions and iteration domains to their position
// in the merged constraint system built by this method.
//
// This method builds a constraint system with the following column format:
//
// [src-dim-identifiers, dst-dim-identifiers, symbols, constant]
//
// For example, given the following MLIR code with "source" and "destination"
// accesses to the same memref labeled, and symbols %M, %N, %K:
//
// for %i0 = 0 to 100 {
// for %i1 = 0 to 50 {
// %a0 = affine_apply
// (d0, d1) -> (d0 * 2 - d1 * 4 + s1, d1 * 3 - s0) (%i0, %i1)[%M, %N]
// // Source memref access.
// store %v0, %m[%a0#0, %a0#1] : memref<4x4xf32>
// }
// }
//
// for %i2 = 0 to 100 {
// for %i3 = 0 to 50 {
// %a1 = affine_apply
// (d0, d1) -> (d0 * 7 + d1 * 9 - s1, d1 * 11 + s0) (%i2, %i3)[%K, %M]
// // Destination memref access.
// %v1 = load %m[%a1#0, %a1#1] : memref<4x4xf32>
// }
// }
//
// The access functions would be the following:
//
// src: (%i0 * 2 - %i1 * 4 + %N, %i1 * 3 - %M)
// dst: (%i2 * 7 + %i3 * 9 - %M, %i3 * 11 + %K)
//
// The iteration domains for the src/dst accesses would be the following:
//
// src: 0 <= %i0 <= 100, 0 <= %i1 <= 50
// dst: 0 <= %i2 <= 100, 0 <= %i3 <= 50
//
// The symbols referenced by both accesses are assigned a canonical position
// order, which is used in the dependence constraint system:
//
// symbol name: %M %N %K
// symbol pos: 0 1 2
//
// Equality constraints are built by equating each result of src/destination
// access functions. For this example, the following two equality constraints
// will be added to the dependence constraint system:
//
// [src_dim0, src_dim1, dst_dim0, dst_dim1, sym0, sym1, sym2, const]
// 2 -4 -7 -9 1 1 0 0 = 0
// 0 3 0 -11 -1 0 -1 0 = 0
//
// Inequality constraints from the iteration domains will be merged into
// the dependence constraint system:
//
// [src_dim0, src_dim1, dst_dim0, dst_dim1, sym0, sym1, sym2, const]
// 1 0 0 0 0 0 0 0 >= 0
// -1 0 0 0 0 0 0 100 >= 0
// 0 1 0 0 0 0 0 0 >= 0
// 0 -1 0 0 0 0 0 50 >= 0
// 0 0 1 0 0 0 0 0 >= 0
// 0 0 -1 0 0 0 0 100 >= 0
// 0 0 0 1 0 0 0 0 >= 0
// 0 0 0 -1 0 0 0 50 >= 0
//
//
// TODO(andydavis) Support AffineExprs mod/floordiv/ceildiv.
// TODO(andydavis) Add precedence order constraints for accesses that
// share a common loop.
bool mlir::checkMemrefAccessDependence(const MemRefAccess &srcAccess,
const MemRefAccess &dstAccess) {
// Return 'false' if these accesses do not access the same memref.
if (srcAccess.memref != dstAccess.memref)
return false;
// Get composed access function for 'srcAccess'.
AffineValueMap srcAccessMap;
srcAccess.getAccessMap(&srcAccessMap);
// Get composed access function for 'dstAccess'.
AffineValueMap dstAccessMap;
dstAccess.getAccessMap(&dstAccessMap);
// Get iteration domain context for 'srcAccess'.
IterationDomainContext srcIterationDomainContext;
if (!getIterationDomainContext(srcAccess.opStmt, &srcIterationDomainContext))
return false;
// Get iteration domain context for 'dstAccess'.
IterationDomainContext dstIterationDomainContext;
if (!getIterationDomainContext(dstAccess.opStmt, &dstIterationDomainContext))
return false;
// Build dim and symbol position maps for each access from access operand
// MLValue to position in the merged constraint system.
DenseMap<const MLValue *, unsigned> srcDimPosMap;
DenseMap<const MLValue *, unsigned> dstDimPosMap;
DenseMap<const MLValue *, unsigned> symbolPosMap;
buildDimAndSymbolPositionMaps(
srcIterationDomainContext, dstIterationDomainContext, srcAccessMap,
dstAccessMap, &srcDimPosMap, &dstDimPosMap, &symbolPosMap);
// TODO(andydavis) Add documentation.
unsigned numIneq = srcIterationDomainContext.domain.getNumInequalities() +
dstIterationDomainContext.domain.getNumInequalities();
AffineMap srcMap = srcAccessMap.getAffineMap();
assert(srcMap.getNumResults() == dstAccessMap.getAffineMap().getNumResults());
unsigned numEq = srcMap.getNumResults();
unsigned numDims = srcDimPosMap.size() + dstDimPosMap.size();
unsigned numSymbols = symbolPosMap.size();
unsigned numIds = numDims + numSymbols;
unsigned numCols = numIds + 1;
// Create flat affine constraints reserving space for 'numEq' and 'numIneq'.
// TODO(andydavis) better name.
FlatAffineConstraints constraints(numIneq, numEq, numCols, numDims,
numSymbols);
// Create memref access constraints by equating the src/dst access functions.
// Note that this check is conservative: it currently fails (and a dependence
// is conservatively assumed) for access functions with mod/div exprs; this
// will improve once local variables for such exprs are supported.
if (!addMemRefAccessConstraints(srcAccessMap, dstAccessMap, srcDimPosMap,
dstDimPosMap, symbolPosMap, &constraints))
return true;
// Add domain constraints for src access function.
addDomainConstraints(srcIterationDomainContext, srcDimPosMap, symbolPosMap,
&constraints);
// Add domain constraints for dst access function.
addDomainConstraints(dstIterationDomainContext, dstDimPosMap, symbolPosMap,
&constraints);
bool isEmpty = constraints.isEmpty();
// Return false if the solution space is empty.
return !isEmpty;
}


@@ -168,10 +168,9 @@ forwardSubstituteMutableAffineMap(const AffineMapCompositionUpdate &mapUpdate,
map->setNumDims(mapUpdate.outputNumDims);
map->setNumSymbols(mapUpdate.outputNumSymbols);
}
MutableAffineMap::MutableAffineMap(AffineMap map)
: numDims(map.getNumDims()), numSymbols(map.getNumSymbols()),
// A map always has at least 1 result by construction
context(map.getResult(0).getContext()) {
for (auto result : map.getResults())
results.push_back(result);
@@ -179,6 +178,19 @@ MutableAffineMap::MutableAffineMap(AffineMap map)
rangeSizes.push_back(rangeSize);
}
void MutableAffineMap::reset(AffineMap map) {
results.clear();
rangeSizes.clear();
numDims = map.getNumDims();
numSymbols = map.getNumSymbols();
// A map always has at least 1 result by construction
context = map.getResult(0).getContext();
for (auto result : map.getResults())
results.push_back(result);
for (auto rangeSize : map.getRangeSizes())
rangeSizes.push_back(rangeSize);
}
bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const {
if (results[idx].isMultipleOf(factor))
return true;
@@ -228,6 +240,15 @@ AffineValueMap::AffineValueMap(AffineMap map, ArrayRef<MLValue *> operands)
}
}
void AffineValueMap::reset(AffineMap map, ArrayRef<MLValue *> operands) {
this->operands.clear();
this->results.clear();
this->map.reset(map);
for (MLValue *operand : operands) {
this->operands.push_back(operand);
}
}
void AffineValueMap::forwardSubstitute(const AffineApplyOp &inputOp) {
SmallVector<bool, 4> inputResultsToSubstitute(inputOp.getNumResults());
for (unsigned i = 0, e = inputOp.getNumResults(); i < e; i++)


@@ -53,21 +53,6 @@ FunctionPass *mlir::createMemRefBoundCheckPass() {
return new MemRefBoundCheck();
}
// Forward substitutes into 'valueMap' all AffineApplyOps reachable from the
// operands of 'valueMap'.
static void forwardSubstituteReachableOps(AffineValueMap *valueMap) {
// Gather AffineApplyOps reachable from 'indices'.
SmallVector<OperationStmt *, 4> affineApplyOps;
getReachableAffineApplyOps(valueMap->getOperands(), affineApplyOps);
// Compose AffineApplyOps in 'affineApplyOps'.
for (auto *opStmt : affineApplyOps) {
assert(opStmt->isa<AffineApplyOp>());
auto affineApplyOp = opStmt->dyn_cast<AffineApplyOp>();
// Forward substitute 'affineApplyOp' into 'valueMap'.
valueMap->forwardSubstitute(*affineApplyOp);
}
}
/// Returns the memory region accessed by this memref.
// TODO(bondhugula): extend this to store's and other memref dereferencing ops.
bool getMemoryRegion(OpPointer<LoadOp> loadOp, FlatAffineConstraints *region) {


@@ -94,6 +94,15 @@ AffineMap AffineMap::getConstantMap(int64_t val, MLIRContext *context) {
{getAffineConstantExpr(val, context)}, {});
}
AffineMap AffineMap::getMultiDimIdentityMap(unsigned numDims,
MLIRContext *context) {
SmallVector<AffineExpr, 4> dimExprs;
dimExprs.reserve(numDims);
for (unsigned i = 0; i < numDims; ++i)
dimExprs.push_back(mlir::getAffineDimExpr(i, context));
return get(/*dimCount=*/numDims, /*symbolCount=*/0, dimExprs, {});
}
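// For reference (editorial note): getMultiDimIdentityMap(2, context) yields
//   (d0, d1) -> (d0, d1)
// with no symbols; MemRefAccess::getAccessMap uses this map as the seed into
// which reachable AffineApplyOps are composed.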
MLIRContext *AffineMap::getContext() const { return getResult(0).getContext(); }
bool AffineMap::isBounded() const { return !map->rangeSizes.empty(); }


@@ -0,0 +1,244 @@
//===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements loop fusion.
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
using namespace mlir;
namespace {
/// Loop fusion pass. This pass fuses adjacent loops in MLFunctions which
/// access the same memref with no dependences.
// See MatchTestPattern for details on candidate loop selection.
// TODO(andydavis) Extend this pass to check for fusion preventing dependences,
// and add support for more general loop fusion algorithms.
struct LoopFusion : public FunctionPass {
LoopFusion() {}
PassResult runOnMLFunction(MLFunction *f) override;
};
// LoopCollector walks the statements in an MLFunction and builds a map from
// StmtBlocks to a list of loops within the StmtBlock, and a map from ForStmts
// to the list of loads and stores within its StmtBlock.
class LoopCollector : public StmtWalker<LoopCollector> {
public:
DenseMap<StmtBlock *, SmallVector<ForStmt *, 2>> loopMap;
DenseMap<ForStmt *, SmallVector<OperationStmt *, 2>> loadsAndStoresMap;
bool hasIfStmt = false;
void visitForStmt(ForStmt *forStmt) {
loopMap[forStmt->getBlock()].push_back(forStmt);
}
void visitIfStmt(IfStmt *ifStmt) { hasIfStmt = true; }
void visitOperationStmt(OperationStmt *opStmt) {
if (auto *parentStmt = opStmt->getParentStmt()) {
if (auto *parentForStmt = dyn_cast<ForStmt>(parentStmt)) {
if (opStmt->isa<LoadOp>() || opStmt->isa<StoreOp>()) {
loadsAndStoresMap[parentForStmt].push_back(opStmt);
}
}
}
}
};
} // end anonymous namespace
FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; }
// TODO(andydavis) Remove the following test code when more general loop
// fusion is supported.
struct FusionCandidate {
// Loop nest of ForStmts with 'accessA' in the inner-most loop.
SmallVector<ForStmt *, 2> forStmtsA;
// Load or store operation within loop nest 'forStmtsA'.
MemRefAccess accessA;
// Loop nest of ForStmts with 'accessB' in the inner-most loop.
SmallVector<ForStmt *, 2> forStmtsB;
// Load or store operation within loop nest 'forStmtsB'.
MemRefAccess accessB;
};
static void getSingleMemRefAccess(OperationStmt *loadOrStoreOpStmt,
MemRefAccess *access) {
if (auto loadOp = loadOrStoreOpStmt->dyn_cast<LoadOp>()) {
access->memref = cast<MLValue>(loadOp->getMemRef());
access->opStmt = loadOrStoreOpStmt;
auto loadMemrefType = loadOp->getMemRefType();
access->indices.reserve(loadMemrefType.getRank());
for (auto *index : loadOp->getIndices()) {
access->indices.push_back(cast<MLValue>(index));
}
} else {
assert(loadOrStoreOpStmt->isa<StoreOp>());
auto storeOp = loadOrStoreOpStmt->dyn_cast<StoreOp>();
access->opStmt = loadOrStoreOpStmt;
access->memref = cast<MLValue>(storeOp->getMemRef());
auto storeMemrefType = storeOp->getMemRefType();
access->indices.reserve(storeMemrefType.getRank());
for (auto *index : storeOp->getIndices()) {
access->indices.push_back(cast<MLValue>(index));
}
}
}
// Checks if 'forStmtA' and 'forStmtB' match the specific test criteria:
// constant loop bounds, no nested loops, a single StoreOp in 'forStmtA' and
// a single LoadOp in 'forStmtB'.
// Returns true if the test pattern matches, false otherwise.
static bool MatchTestPatternLoopPair(LoopCollector *lc,
FusionCandidate *candidate,
ForStmt *forStmtA, ForStmt *forStmtB) {
if (forStmtA == nullptr || forStmtB == nullptr)
return false;
// Return if 'forStmtA' and 'forStmtB' do not have matching constant
// bounds and step.
if (!forStmtA->hasConstantBounds() || !forStmtB->hasConstantBounds() ||
forStmtA->getConstantLowerBound() != forStmtB->getConstantLowerBound() ||
forStmtA->getConstantUpperBound() != forStmtB->getConstantUpperBound() ||
forStmtA->getStep() != forStmtB->getStep())
return false;
// Return if 'forStmtA' or 'forStmtB' have nested loops.
if (lc->loopMap.count(forStmtA) > 0 || lc->loopMap.count(forStmtB) > 0)
return false;
// Return if 'forStmtA' or 'forStmtB' do not have exactly one load or store.
if (lc->loadsAndStoresMap[forStmtA].size() != 1 ||
lc->loadsAndStoresMap[forStmtB].size() != 1)
return false;
// Get load/store access for forStmtA.
getSingleMemRefAccess(lc->loadsAndStoresMap[forStmtA][0],
&candidate->accessA);
// Return if 'accessA' is not a store.
if (!candidate->accessA.opStmt->isa<StoreOp>())
return false;
// Get load/store access for forStmtB.
getSingleMemRefAccess(lc->loadsAndStoresMap[forStmtB][0],
&candidate->accessB);
// Return if accesses do not access the same memref.
if (candidate->accessA.memref != candidate->accessB.memref)
return false;
candidate->forStmtsA.push_back(forStmtA);
candidate->forStmtsB.push_back(forStmtB);
return true;
}
// Returns the child ForStmt of 'parent' if unique, returns nullptr otherwise.
static ForStmt *getSingleForStmtChild(ForStmt *parent) {
if (parent->getStatements().size() == 1 && isa<ForStmt>(parent->front()))
return dyn_cast<ForStmt>(&parent->front());
return nullptr;
}
// Checks for a specific ForStmt/OperationStmt test pattern in 'f'; returns
// true on success and returns the fusion candidate in 'candidate'. Returns
// false otherwise.
// Currently supported test patterns:
// *) Adjacent loops with a StoreOp as the only op in the first loop, and a
// LoadOp as the only op in the second loop (both accessing the same memref).
// *) As above, but with one level of perfect loop nesting.
//
// TODO(andydavis) Look into using ntv@ pattern matcher here.
static bool MatchTestPattern(MLFunction *f, FusionCandidate *candidate) {
LoopCollector lc;
lc.walk(f);
// Return if an IfStmt was found or if less than two ForStmts were found.
if (lc.hasIfStmt || lc.loopMap.count(f) == 0 || lc.loopMap[f].size() < 2)
return false;
auto *forStmtA = lc.loopMap[f][0];
auto *forStmtB = lc.loopMap[f][1];
if (!MatchTestPatternLoopPair(&lc, candidate, forStmtA, forStmtB)) {
// Check for one level of loop nesting.
candidate->forStmtsA.push_back(forStmtA);
candidate->forStmtsB.push_back(forStmtB);
return MatchTestPatternLoopPair(&lc, candidate,
getSingleForStmtChild(forStmtA),
getSingleForStmtChild(forStmtB));
}
return true;
}
// fuseLoops implements the code generation mechanics of loop fusion.
// Fuses the operation statements from the inner-most loop in 'c.forStmtsB',
// by cloning them into the inner-most loop in 'c.forStmtsA', then erasing
// the old statements and loops.
static void fuseLoops(const FusionCandidate &c) {
MLFuncBuilder builder(c.forStmtsA.back(),
StmtBlock::iterator(c.forStmtsA.back()->end()));
DenseMap<const MLValue *, MLValue *> operandMap;
assert(c.forStmtsA.size() == c.forStmtsB.size());
for (unsigned i = 0, e = c.forStmtsA.size(); i < e; i++) {
// Map the loop IV of 'forStmtsB[i]' to the loop IV of 'forStmtsA[i]'.
operandMap[c.forStmtsB[i]] = c.forStmtsA[i];
}
// Clone the body of inner-most loop in 'forStmtsB', into the body of
// inner-most loop in 'forStmtsA'.
SmallVector<Statement *, 2> stmtsToErase;
auto *innerForStmtB = c.forStmtsB.back();
for (auto &stmt : *innerForStmtB) {
builder.clone(stmt, operandMap);
stmtsToErase.push_back(&stmt);
}
// Erase the original statements from 'innerForStmtB'.
for (auto it = stmtsToErase.rbegin(); it != stmtsToErase.rend(); ++it)
(*it)->erase();
// Erase 'forStmtsB' loop nest.
for (int i = static_cast<int>(c.forStmtsB.size()) - 1; i >= 0; --i)
c.forStmtsB[i]->erase();
}
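// Schematic before/after (editorial sketch matching the 1-d should-fuse test
// below; operandMap remaps IV %i1 to %i0 while cloning):
//
//   Before:                     After:
//   for %i0 = 0 to 100 {        for %i0 = 0 to 100 {
//     store %c1, %m[...]          store %c1, %m[...]
//   }                             %v0 = load %m[...]
//   for %i1 = 0 to 100 {        }
//     %v0 = load %m[...]
//   }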
PassResult LoopFusion::runOnMLFunction(MLFunction *f) {
FusionCandidate candidate;
if (!MatchTestPattern(f, &candidate))
return failure();
// TODO(andydavis) Add checks for fusion-preventing dependences and ordering
// constraints which would prevent fusion.
// TODO(andydavis) This check is overly conservative for now. Support fusing
// statements with compatible dependences (i.e. statements where the
// dependence between the statements does not reverse direction when the
// statements are fused into the same loop).
if (!checkMemrefAccessDependence(candidate.accessA, candidate.accessB)) {
// Current conservative test policy: no dependence exists between accesses
// in different loop nests -> fuse loops.
fuseLoops(candidate);
}
return success();
}


@@ -0,0 +1,191 @@
// RUN: mlir-opt %s -loop-fusion | FileCheck %s
// CHECK: [[MAP0:#map[0-9]+]] = (d0) -> (d0 * 2 + 2)
// CHECK: [[MAP1:#map[0-9]+]] = (d0) -> (d0 * 3 + 1)
// CHECK: [[MAP2:#map[0-9]+]] = (d0) -> (d0 * 2)
// CHECK: [[MAP3:#map[0-9]+]] = (d0) -> (d0 * 2 + 1)
// CHECK: [[MAP4:#map[0-9]+]] = (d0, d1)[s0, s1] -> (d0 * 2 - d1 - s0 * 7 + 3, d0 * 9 + d1 * 3 + s1 * 13 - 10)
// CHECK: [[MAP5:#map[0-9]+]] = (d0, d1)[s0, s1] -> (d0 * 2 - 1, d1 * 3 + s0 * 2 + s1 * 3)
// CHECK: [[MAP6:#map[0-9]+]] = (d0, d1)[s0, s1] -> (d0 * 2 - 1, d1 * 3 + s0 + s1 * 3)
// The dependence check for this test builds the following set of constraints,
// where the equality constraint equates the two accesses to the memref (from
// different loops), and the inequality constraints represent the upper and
// lower bounds for each loop. After elimination, this linear system can be
// shown to be non-empty (i.e. x0 = x1 = 1 is a solution). As such, the
// dependence check between accesses in the two loops will return true, and
// the loops (according to the current test loop fusion algorithm) should not be
// fused.
//
// x0 x1 x2
// 2 -3 1 = 0
// 1 0 0 >= 0
// -1 0 100 >= 0
// 0 1 0 >= 0
// 0 -1 100 >= 0
//
// CHECK-LABEL: mlfunc @loop_fusion_1d_should_not_fuse_loops() {
mlfunc @loop_fusion_1d_should_not_fuse_loops() {
%m = alloc() : memref<100xf32, (d0) -> (d0)>
// Check that the first loop remains unfused.
// CHECK: for %i0 = 0 to 100 {
// CHECK-NEXT: [[I0:%[0-9]+]] = affine_apply [[MAP0]](%i0)
// CHECK: store {{.*}}, %{{[0-9]+}}{{\[}}[[I0]]{{\]}}
// CHECK-NEXT: }
for %i0 = 0 to 100 {
%a0 = affine_apply (d0) -> (d0 * 2 + 2) (%i0)
%c1 = constant 1.0 : f32
store %c1, %m[%a0] : memref<100xf32, (d0) -> (d0)>
}
// Check that the second loop remains unfused.
// CHECK: for %i1 = 0 to 100 {
// CHECK-NEXT: [[I1:%[0-9]+]] = affine_apply [[MAP1]](%i1)
// CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I1]]{{\]}}
// CHECK-NEXT: }
for %i1 = 0 to 100 {
%a1 = affine_apply (d0) -> (d0 * 3 + 1) (%i1)
%v0 = load %m[%a1] : memref<100xf32, (d0) -> (d0)>
}
return
}
// The dependence check for this test builds the following set of constraints:
//
// x0 x1 x2
// 2 -2 -1 = 0
// 1 0 0 >= 0
// -1 0 100 >= 0
// 0 1 0 >= 0
// 0 -1 100 >= 0
//
// After elimination, this linear system can be shown to have no solutions, and
// so no dependence exists and the loops should be fused in this test (according
// to the current trivial test loop fusion policy).
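// (Concretely, the equality reads 2*x0 - 2*x1 = 1; the left-hand side is
// always even, so no integer solution exists.)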
//
//
// CHECK-LABEL: mlfunc @loop_fusion_1d_should_fuse_loops() {
mlfunc @loop_fusion_1d_should_fuse_loops() {
%m = alloc() : memref<100xf32, (d0) -> (d0)>
// Should fuse statements from the second loop into the first loop.
// CHECK: for %i0 = 0 to 100 {
// CHECK-NEXT: [[I0:%[0-9]+]] = affine_apply [[MAP2]](%i0)
// CHECK: store {{.*}}, %{{[0-9]+}}{{\[}}[[I0]]{{\]}}
// CHECK-NEXT: [[I1:%[0-9]+]] = affine_apply [[MAP3]](%i0)
// CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I1]]{{\]}}
// CHECK-NEXT: }
// CHECK-NEXT: return
for %i0 = 0 to 100 {
%a0 = affine_apply (d0) -> (d0 * 2) (%i0)
%c1 = constant 1.0 : f32
store %c1, %m[%a0] : memref<100xf32, (d0) -> (d0)>
}
for %i1 = 0 to 100 {
%a1 = affine_apply (d0) -> (d0 * 2 + 1) (%i1)
%v0 = load %m[%a1] : memref<100xf32, (d0) -> (d0)>
}
return
}
// The dependence check for this test builds the following set of
// equality constraints (one for each memref dimension). Note: inequality
// constraints for loop bounds not shown.
//
// i0 i1 i2 i3 s0 s1 s2 c
// 2 -1 -2 0 -7 0 0 4 = 0
// 9 3 0 -3 0 11 -3 -10 = 0
//
// CHECK-LABEL: mlfunc @loop_fusion_2d_should_not_fuse_loops() {
mlfunc @loop_fusion_2d_should_not_fuse_loops() {
%m = alloc() : memref<10x10xf32>
%s0 = constant 7 : index
%s1 = constant 11 : index
%s2 = constant 13 : index
// Check that the first loop remains unfused.
// CHECK: for %i0 = 0 to 100 {
// CHECK-NEXT: for %i1 = 0 to 50 {
// CHECK-NEXT: [[I0:%[0-9]+]] = affine_apply [[MAP4]](%i0, %i1)[%c7, %c11]
// CHECK: store {{.*}}, %{{[0-9]+}}{{\[}}[[I0]]#0, [[I0]]#1{{\]}}
// CHECK-NEXT: }
// CHECK-NEXT: }
for %i0 = 0 to 100 {
for %i1 = 0 to 50 {
%a0 = affine_apply
(d0, d1)[s0, s1] ->
(d0 * 2 - d1 - s0 * 7 + 3, d0 * 9 + d1 * 3 + s1 * 13 - 10)
(%i0, %i1)[%s0, %s1]
%c1 = constant 1.0 : f32
store %c1, %m[%a0#0, %a0#1] : memref<10x10xf32>
}
}
// Check that the second loop remains unfused.
// CHECK: for %i2 = 0 to 100 {
// CHECK-NEXT: for %i3 = 0 to 50 {
// CHECK-NEXT: [[I1:%[0-9]+]] = affine_apply [[MAP5]](%i2, %i3)[%c11, %c13]
// CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I1]]#0, [[I1]]#1{{\]}}
// CHECK-NEXT: }
// CHECK-NEXT: }
for %i2 = 0 to 100 {
for %i3 = 0 to 50 {
%a1 = affine_apply
(d0, d1)[s0, s1] ->
(d0 * 2 - 1, d1 * 3 + s0 * 2 + s1 * 3) (%i2, %i3)[%s1, %s2]
%v0 = load %m[%a1#0, %a1#1] : memref<10x10xf32>
}
}
return
}
// The dependence check for this test builds the following set of
// equality constraints (one for each memref dimension). Note: inequality
// constraints for loop bounds not shown.
//
// i0 i1 i2 i3 s0 s1 s2 c
// 2 -1 -2 0 -7 0 0 4 = 0
// 9 3 0 -3 0 12 -3 -10 = 0
//
// The second equality fails the GCD test, so the system has no solution and
// the loops should be fused under the current test policy.
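// (Concretely, the second equality reads 9*i0 + 3*i1 - 3*i3 + 12*s1 - 3*s2 = 10,
// and gcd(9, 3, 3, 12, 3) = 3 does not divide 10, hence no integer solution.)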
//
// CHECK-LABEL: mlfunc @loop_fusion_2d_should_fuse_loops() {
mlfunc @loop_fusion_2d_should_fuse_loops() {
%m = alloc() : memref<10x10xf32>
%s0 = constant 7 : index
%s1 = constant 11 : index
%s2 = constant 13 : index
// Should fuse statements from the second loop into the first loop.
// CHECK: for %i0 = 0 to 100 {
// CHECK-NEXT: for %i1 = 0 to 50 {
// CHECK-NEXT: [[I0:%[0-9]+]] = affine_apply [[MAP4]](%i0, %i1)[%c7, %c11]
// CHECK: store {{.*}}, %{{[0-9]+}}{{\[}}[[I0]]#0, [[I0]]#1{{\]}}
// CHECK-NEXT: [[I1:%[0-9]+]] = affine_apply [[MAP6]](%i0, %i1)[%c11, %c13]
// CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I1]]#0, [[I1]]#1{{\]}}
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return
for %i0 = 0 to 100 {
for %i1 = 0 to 50 {
%a0 = affine_apply
(d0, d1)[s0, s1] ->
(d0 * 2 - d1 - s0 * 7 + 3, d0 * 9 + d1 * 3 + s1 * 13 - 10)
(%i0, %i1)[%s0, %s1]
%c1 = constant 1.0 : f32
store %c1, %m[%a0#0, %a0#1] : memref<10x10xf32>
}
}
for %i2 = 0 to 100 {
for %i3 = 0 to 50 {
%a1 = affine_apply
(d0, d1)[s0, s1] ->
(d0 * 2 - 1, d1 * 3 + s0 + s1 * 3) (%i2, %i3)[%s1, %s2]
%v0 = load %m[%a1#0, %a1#1] : memref<10x10xf32>
}
}
return
}


@@ -72,6 +72,7 @@ enum Passes {
ConstantFold,
ConvertToCFG,
MemRefBoundCheck,
LoopFusion,
LoopUnroll,
LoopUnrollAndJam,
PipelineDataTransfer,
@@ -94,6 +95,7 @@ static cl::list<Passes> passList(
"Convert all ML functions in the module to CFG ones"),
clEnumValN(MemRefBoundCheck, "memref-bound-check",
"Convert all ML functions in the module to CFG ones"),
clEnumValN(LoopFusion, "loop-fusion", "Fuse loop nests"),
clEnumValN(LoopUnroll, "loop-unroll", "Unroll loops"),
clEnumValN(LoopUnrollAndJam, "loop-unroll-jam", "Unroll and jam loops"),
clEnumValN(PipelineDataTransfer, "pipeline-data-transfer",
@@ -198,6 +200,9 @@ static OptResult performActions(SourceMgr &sourceMgr, MLIRContext *context) {
case MemRefBoundCheck:
pass = createMemRefBoundCheckPass();
break;
case LoopFusion:
pass = createLoopFusionPass();
break;
case LoopUnroll:
pass = createLoopUnrollPass();
break;