Define the AffineForOp and replace ForInst with it. This patch is largely mechanical, i.e., it changes usages of ForInst to OpPointer<AffineForOp>. One important difference: upon construction, an AffineForOp no longer automatically creates its body and induction variable. To generate the body/iv, call 'createBody' on an AffineForOp that has no body.
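A hedged sketch of the new creation pattern, assuming FuncBuilder's create<OpTy> template and the declarations added in this patch (the builder b and Location loc are hypothetical):

// Before: ForInst *forInst = b.createFor(loc, /*lb=*/0, /*ub=*/10);
//         and the body block plus induction variable already existed.
OpPointer<AffineForOp> forOp = b.create<AffineForOp>(loc, /*lb=*/0, /*ub=*/10);
Block *body = forOp->createBody();     // creates the entry block and the index iv
Value *iv = forOp->getInductionVar();  // the entry block's BlockArgument
FuncBuilder bodyBuilder(body, body->end());
// ... emit the loop body using bodyBuilder ...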

PiperOrigin-RevId: 232060516
This commit is contained in:
River Riddle 2019-02-01 16:42:18 -08:00 committed by jpienaar
parent e0774c008f
commit 5052bd8582
52 changed files with 1569 additions and 1730 deletions

View File

@ -28,13 +28,230 @@
#include "mlir/IR/StandardTypes.h"
namespace mlir {
class AffineBound;
class AffineOpsDialect : public Dialect {
public:
AffineOpsDialect(MLIRContext *context);
};
/// The "if" operation represents an ifthenelse construct for conditionally
/// The "for" instruction represents an affine loop nest, defining an SSA value
/// for its induction variable. The induction variable is represented as a
/// BlockArgument to the entry block of the body. The body and induction
/// variable can be created automatically for new "for" ops with 'createBody'.
/// This SSA value always has type index, which is the size of the machine word.
/// The stride, represented by step, is a positive constant integer which
/// defaults to "1" if not present. The lower and upper bounds specify a
/// half-open range: the range includes the lower bound but does not include the
/// upper bound.
///
/// The lower and upper bounds of a for operation are represented as an
/// application of an affine mapping to a list of SSA values passed to the map.
/// The same restrictions hold for these SSA values as for all bindings of SSA
/// values to dimensions and symbols. The affine mappings for the bounds may
/// return multiple results, in which case the max/min keywords are required
/// (for the lower/upper bound respectively), and the bound is the
/// maximum/minimum of the returned values.
///
/// Example:
///
/// for %i = 1 to 10 {
/// ...
/// }
///
class AffineForOp
: public Op<AffineForOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
public:
// Hooks to customize behavior of this op.
static void build(Builder *builder, OperationState *result,
ArrayRef<Value *> lbOperands, AffineMap lbMap,
ArrayRef<Value *> ubOperands, AffineMap ubMap,
int64_t step = 1);
static void build(Builder *builder, OperationState *result, int64_t lb,
int64_t ub, int64_t step = 1);
bool verify() const;
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
static StringRef getOperationName() { return "for"; }
static StringRef getStepAttrName() { return "step"; }
static StringRef getLowerBoundAttrName() { return "lower_bound"; }
static StringRef getUpperBoundAttrName() { return "upper_bound"; }
/// Generate a body block for this AffineForOp. The operation must not already
/// have a body. The operation must be contained in a function.
Block *createBody();
/// Get the body of the AffineForOp.
Block *getBody() { return &getBlockList().front(); }
const Block *getBody() const { return &getBlockList().front(); }
/// Get the blocklist containing the body.
BlockList &getBlockList() { return getInstruction()->getBlockList(0); }
const BlockList &getBlockList() const {
return getInstruction()->getBlockList(0);
}
/// Returns the induction variable for this loop.
Value *getInductionVar();
const Value *getInductionVar() const {
return const_cast<AffineForOp *>(this)->getInductionVar();
}
//===--------------------------------------------------------------------===//
// Bounds and step
//===--------------------------------------------------------------------===//
using operand_range = llvm::iterator_range<operand_iterator>;
using const_operand_range = llvm::iterator_range<const_operand_iterator>;
// TODO: provide iterators for the lower and upper bound operands
// if the current access via getLowerBound(), getUpperBound() is too slow.
/// Returns operands for the lower bound map.
operand_range getLowerBoundOperands();
const_operand_range getLowerBoundOperands() const;
/// Returns operands for the upper bound map.
operand_range getUpperBoundOperands();
const_operand_range getUpperBoundOperands() const;
/// Returns information about the lower bound as a single object.
const AffineBound getLowerBound() const;
/// Returns information about the upper bound as a single object.
const AffineBound getUpperBound() const;
/// Returns loop step.
int64_t getStep() const {
return getAttr(getStepAttrName()).cast<IntegerAttr>().getInt();
}
/// Returns affine map for the lower bound.
AffineMap getLowerBoundMap() const {
return getAttr(getLowerBoundAttrName()).cast<AffineMapAttr>().getValue();
}
/// Returns affine map for the upper bound. The upper bound is exclusive.
AffineMap getUpperBoundMap() const {
return getAttr(getUpperBoundAttrName()).cast<AffineMapAttr>().getValue();
}
/// Set lower bound. The new bound must have the same number of operands as
/// the current bound map. Otherwise, 'replaceForLowerBound' should be used.
void setLowerBound(ArrayRef<Value *> operands, AffineMap map);
/// Set upper bound. The new bound must not have more operands than the
/// current bound map. Otherwise, 'replaceForUpperBound' should be used.
void setUpperBound(ArrayRef<Value *> operands, AffineMap map);
/// Set the lower bound map without changing operands.
void setLowerBoundMap(AffineMap map);
/// Set the upper bound map without changing operands.
void setUpperBoundMap(AffineMap map);
/// Set loop step.
void setStep(int64_t step) {
assert(step > 0 && "step has to be a positive integer constant");
auto *context = getLowerBoundMap().getContext();
setAttr(Identifier::get(getStepAttrName(), context),
IntegerAttr::get(IndexType::get(context), step));
}
/// Returns true if the lower bound is constant.
bool hasConstantLowerBound() const;
/// Returns true if the upper bound is constant.
bool hasConstantUpperBound() const;
/// Returns true if both bounds are constant.
bool hasConstantBounds() const {
return hasConstantLowerBound() && hasConstantUpperBound();
}
/// Returns the value of the constant lower bound.
/// Fails assertion if the bound is non-constant.
int64_t getConstantLowerBound() const;
/// Returns the value of the constant upper bound. The upper bound is
/// exclusive. Fails assertion if the bound is non-constant.
int64_t getConstantUpperBound() const;
/// Sets the lower bound to the given constant value.
void setConstantLowerBound(int64_t value);
/// Sets the upper bound to the given constant value.
void setConstantUpperBound(int64_t value);
/// Returns true if both the lower and upper bound have the same operand lists
/// (same operands in the same order).
bool matchingBoundOperandList() const;
/// Walk the operation instructions in the 'for' instruction in preorder,
/// calling the callback for each operation.
void walkOps(std::function<void(OperationInst *)> callback);
/// Walk the operation instructions in the 'for' instruction in postorder,
/// calling the callback for each operation.
void walkOpsPostOrder(std::function<void(OperationInst *)> callback);
private:
friend class OperationInst;
explicit AffineForOp(const OperationInst *state) : Op(state) {}
};
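For illustration, a hedged sketch of the constant-bound and step accessors above (forOp is an assumed OpPointer<AffineForOp>):

if (forOp->hasConstantBounds()) {
  // The upper bound is exclusive, so the iteration space is [lb, ub).
  int64_t lb = forOp->getConstantLowerBound();
  int64_t ub = forOp->getConstantUpperBound();
  int64_t step = forOp->getStep();
  int64_t tripCount = ub > lb ? (ub - lb + step - 1) / step : 0;
  (void)tripCount;
}
forOp->setConstantUpperBound(128); // installs a single-result constant bound map
forOp->setStep(2);                 // stored as an index-typed IntegerAttr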
/// Returns true if the provided value is the induction variable of an AffineForOp.
bool isForInductionVar(const Value *val);
/// Returns the loop parent of an induction variable. If the provided value is
/// not an induction variable, then return nullptr.
OpPointer<AffineForOp> getForInductionVarOwner(Value *val);
ConstOpPointer<AffineForOp> getForInductionVarOwner(const Value *val);
/// Extracts the induction variables from a list of AffineForOps and returns
/// them.
SmallVector<Value *, 8>
extractForInductionVars(MutableArrayRef<OpPointer<AffineForOp>> forInsts);
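A hedged example of recovering the owning loop from a value with the helpers above (assumes llvm::errs from llvm/Support/raw_ostream.h):

void inspectIndex(Value *idx) {
  if (!isForInductionVar(idx))
    return;
  OpPointer<AffineForOp> loop = getForInductionVarOwner(idx);
  assert(loop && "induction variable without an owning loop");
  if (loop->hasConstantUpperBound())
    llvm::errs() << "iv bounded by " << loop->getConstantUpperBound() << "\n";
}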
/// AffineBound represents a lower or upper bound in the for instruction.
/// This class does not own the underlying operands. Instead, it refers
/// to the operands stored in the AffineForOp. Its life span should not exceed
/// that of the for instruction it refers to.
class AffineBound {
public:
ConstOpPointer<AffineForOp> getAffineForOp() const { return inst; }
AffineMap getMap() const { return map; }
unsigned getNumOperands() const { return opEnd - opStart; }
const Value *getOperand(unsigned idx) const {
return inst->getInstruction()->getOperand(opStart + idx);
}
using operand_iterator = AffineForOp::operand_iterator;
using operand_range = AffineForOp::operand_range;
operand_iterator operand_begin() const {
return const_cast<OperationInst *>(inst->getInstruction())
->operand_begin() +
opStart;
}
operand_iterator operand_end() const {
return const_cast<OperationInst *>(inst->getInstruction())
->operand_begin() +
opEnd;
}
operand_range getOperands() const { return {operand_begin(), operand_end()}; }
private:
// 'for' instruction that contains this bound.
ConstOpPointer<AffineForOp> inst;
// Start and end positions of this affine bound operands in the list of
// the containing 'for' instruction operands.
unsigned opStart, opEnd;
// Affine map for this bound.
AffineMap map;
AffineBound(ConstOpPointer<AffineForOp> inst, unsigned opStart,
unsigned opEnd, AffineMap map)
: inst(inst), opStart(opStart), opEnd(opEnd), map(map) {}
friend class AffineForOp;
};
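A hedged sketch of reading a bound through AffineBound (the AffineForOp must outlive the bound object):

const AffineBound lb = forOp->getLowerBound();
AffineMap map = lb.getMap();
for (unsigned i = 0, e = lb.getNumOperands(); i != e; ++i) {
  // Operands are positionally aligned with the map's dims, then symbols.
  const Value *operand = lb.getOperand(i);
  (void)operand;
}
(void)map;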
/// The "if" operation represents an if-then-else construct for conditionally
/// executing two regions of code. The operands to an if operation are an
/// IntegerSet condition and a set of symbol/dimension operands to the
/// condition set. The operation produces no results. For example:

View File

@ -32,10 +32,10 @@ namespace mlir {
class AffineApplyOp;
class AffineExpr;
class AffineForOp;
class AffineMap;
class AffineValueMap;
class FlatAffineConstraints;
class ForInst;
class FuncBuilder;
class Instruction;
class IntegerSet;
@ -108,12 +108,12 @@ bool getFlattenedAffineExprs(
FlatAffineConstraints *cst = nullptr);
/// Builds a system of constraints with dimensional identifiers corresponding to
/// the loop IVs of the forInsts appearing in that order. Bounds of the loop are
/// the loop IVs of the forOps appearing in that order. Bounds of the loop are
/// used to add appropriate inequalities. Any symbols found in the bound
/// operands are added as symbols in the system. Returns false for the yet
/// unimplemented cases.
// TODO(bondhugula): handle non-unit strides.
bool getIndexSet(llvm::ArrayRef<ForInst *> forInsts,
bool getIndexSet(llvm::MutableArrayRef<OpPointer<AffineForOp>> forOps,
FlatAffineConstraints *domain);
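A hedged sketch, mirroring getInstIndexSet in AffineAnalysis.cpp further below, of building the domain of the loops surrounding an instruction:

SmallVector<OpPointer<AffineForOp>, 4> loops;
getLoopIVs(*inst, &loops);          // surrounding loops, outermost first
FlatAffineConstraints domain;
if (!getIndexSet(loops, &domain))
  return;                           // e.g. non-unit strides are not handled yet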
/// Encapsulates a memref load or store access information.

View File

@ -28,9 +28,10 @@ namespace mlir {
class AffineApplyOp;
class AffineBound;
class AffineForOp;
class AffineCondition;
class AffineMap;
class ForInst;
template <typename T> class ConstOpPointer;
class IntegerSet;
class MLIRContext;
class Value;
@ -113,13 +114,12 @@ private:
/// results, and its map can themselves change as a result of
/// substitutions, simplifications, and other analysis.
// An affine value map can readily be constructed from an AffineApplyOp, or an
// AffineBound of a ForInst. It can be further transformed, substituted into,
// or simplified. Unlike AffineMap's, AffineValueMap's are created and destroyed
// during analysis. Only the AffineMap expressions that are pointed by them are
// unique'd.
// An affine value map, and the operations on it, maintain the invariant that
// operands are always positionally aligned with the AffineDimExpr and
// AffineSymbolExpr in the underlying AffineMap.
// AffineBound of a AffineForOp. It can be further transformed, substituted
// into, or simplified. Unlike AffineMap's, AffineValueMap's are created and
// destroyed during analysis. Only the AffineMap expressions that are pointed by
// them are unique'd. An affine value map, and the operations on it, maintain
// the invariant that operands are always positionally aligned with the
// AffineDimExpr and AffineSymbolExpr in the underlying AffineMap.
// TODO(bondhugula): Some of these classes could go into separate files.
class AffineValueMap {
public:
@ -173,9 +173,6 @@ private:
// Both, the integer set being pointed to and the operands can change during
// analysis, simplification, and transformation.
class IntegerValueSet {
// Constructs an integer value set map from an IntegerSet and operands.
explicit IntegerValueSet(const AffineCondition &cond);
/// Constructs an integer value set from an affine value map.
// This will lead to a single equality in 'set'.
explicit IntegerValueSet(const AffineValueMap &avm);
@ -403,7 +400,7 @@ public:
/// Adds constraints (lower and upper bounds) for the specified 'for'
/// instruction's Value using IR information stored in its bound maps. The
/// right identifier is first looked up using forInst's Value. Returns
/// right identifier is first looked up using forOp's Value. Returns
/// false for the yet unimplemented/unsupported cases, and true if the
/// information is successfully added. Asserts if the Value corresponding to
/// the 'for' instruction isn't found in the constraint system. Any new
@ -411,7 +408,7 @@ public:
/// are added as trailing identifiers (either dimensional or symbolic
/// depending on whether the operand is a valid ML Function symbol).
// TODO(bondhugula): add support for non-unit strides.
bool addForInstDomain(const ForInst &forInst);
bool addAffineForOpDomain(ConstOpPointer<AffineForOp> forOp);
/// Adds a constant lower bound constraint for the specified expression.
void addConstantLowerBound(ArrayRef<int64_t> expr, int64_t lb);

View File

@ -29,8 +29,9 @@
namespace mlir {
class AffineExpr;
class AffineForOp;
class AffineMap;
class ForInst;
template <typename T> class ConstOpPointer;
class MemRefType;
class OperationInst;
class Value;
@ -38,19 +39,20 @@ class Value;
/// Returns the trip count of the loop as an affine expression if the latter is
/// expressible as an affine expression, and nullptr otherwise. The trip count
/// expression is simplified before returning.
AffineExpr getTripCountExpr(const ForInst &forInst);
AffineExpr getTripCountExpr(ConstOpPointer<AffineForOp> forOp);
/// Returns the trip count of the loop if it's a constant, None otherwise. This
/// uses affine expression analysis and is able to determine constant trip count
/// in non-trivial cases.
llvm::Optional<uint64_t> getConstantTripCount(const ForInst &forInst);
llvm::Optional<uint64_t>
getConstantTripCount(ConstOpPointer<AffineForOp> forOp);
/// Returns the greatest known integral divisor of the trip count. Affine
/// expression analysis is used (indirectly through getTripCount), and
/// this method is thus able to determine non-trivial divisors.
uint64_t getLargestDivisorOfTripCount(const ForInst &forInst);
uint64_t getLargestDivisorOfTripCount(ConstOpPointer<AffineForOp> forOp);
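For instance, a hedged check that gates unrolling on trip-count information:

if (llvm::Optional<uint64_t> tripCount = getConstantTripCount(forOp)) {
  if (*tripCount <= 1)
    return; // nothing worth unrolling
}
uint64_t divisor = getLargestDivisorOfTripCount(forOp);
(void)divisor; // e.g. choose an unroll factor that divides the trip count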
/// Given an induction variable `iv` of type ForInst and an `index` of type
/// Given an induction variable `iv` of type AffineForOp and an `index` of type
/// IndexType, returns `true` if `index` is independent of `iv` and false
/// otherwise.
/// The determination supports composition with at most one AffineApplyOp.
@ -67,7 +69,7 @@ uint64_t getLargestDivisorOfTripCount(const ForInst &forInst);
/// conservative.
bool isAccessInvariant(const Value &iv, const Value &index);
/// Given an induction variable `iv` of type ForInst and `indices` of type
/// Given an induction variable `iv` of type AffineForOp and `indices` of type
/// IndexType, returns the set of `indices` that are independent of `iv`.
///
/// Prerequisites (inherited from `isAccessInvariant` above):
@ -85,21 +87,21 @@ getInvariantAccesses(const Value &iv, llvm::ArrayRef<const Value *> indices);
/// 3. all nested load/stores are to scalar MemRefs.
/// TODO(ntv): implement dependence semantics
/// TODO(ntv): relax the no-conditionals restriction
bool isVectorizableLoop(const ForInst &loop);
bool isVectorizableLoop(ConstOpPointer<AffineForOp> loop);
/// Checks whether the loop is structurally vectorizable and that all the LoadOp
/// and StoreOp matched have access indexing functions that are either:
/// 1. invariant along the loop induction variable created by 'loop';
/// 2. varying along the 'fastestVaryingDim' memory dimension.
bool isVectorizableLoopAlongFastestVaryingMemRefDim(const ForInst &loop,
unsigned fastestVaryingDim);
bool isVectorizableLoopAlongFastestVaryingMemRefDim(
ConstOpPointer<AffineForOp> loop, unsigned fastestVaryingDim);
/// Checks whether SSA dominance would be violated if a for inst's body
/// instructions are shifted by the specified shifts. This method checks if a
/// 'def' and all its uses have the same shift factor.
// TODO(mlir-team): extend this to check for memory-based dependence
// violation when we have the support.
bool isInstwiseShiftValid(const ForInst &forInst,
bool isInstwiseShiftValid(ConstOpPointer<AffineForOp> forOp,
llvm::ArrayRef<uint64_t> shifts);
} // end namespace mlir

View File

@ -127,7 +127,6 @@ private:
struct State : public InstWalker<State> {
State(NestedPattern &pattern, SmallVectorImpl<NestedMatch> *matches)
: pattern(pattern), matches(matches) {}
void visitForInst(ForInst *forInst) { pattern.matchOne(forInst, matches); }
void visitOperationInst(OperationInst *opInst) {
pattern.matchOne(opInst, matches);
}

View File

@ -33,10 +33,12 @@
namespace mlir {
class AffineForOp;
template <typename T> class ConstOpPointer;
class FlatAffineConstraints;
class ForInst;
class MemRefAccess;
class OperationInst;
template <typename T> class OpPointer;
class Instruction;
class Value;
@ -49,7 +51,8 @@ bool properlyDominates(const Instruction &a, const Instruction &b);
/// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from
/// the outermost 'for' instruction to the innermost one.
// TODO(bondhugula): handle 'if' inst's.
void getLoopIVs(const Instruction &inst, SmallVectorImpl<ForInst *> *loops);
void getLoopIVs(const Instruction &inst,
SmallVectorImpl<OpPointer<AffineForOp>> *loops);
/// Returns the nesting depth of this instruction, i.e., the number of loops
/// surrounding this instruction.
@ -191,12 +194,12 @@ bool getBackwardComputationSliceState(const MemRefAccess &srcAccess,
// materialize the results of the backward slice - presenting a trade-off b/w
// storage and redundant computation in several cases.
// TODO(andydavis) Support computation slices with common surrounding loops.
ForInst *insertBackwardComputationSlice(OperationInst *srcOpInst,
OperationInst *dstOpInst,
unsigned dstLoopDepth,
ComputationSliceState *sliceState);
OpPointer<AffineForOp>
insertBackwardComputationSlice(OperationInst *srcOpInst,
OperationInst *dstOpInst, unsigned dstLoopDepth,
ComputationSliceState *sliceState);
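A hedged usage sketch; the returned OpPointer is null when no slice could be inserted (srcOpInst and dstOpInst are assumed OperationInst pointers):

ComputationSliceState sliceState;
OpPointer<AffineForOp> sliceLoop = insertBackwardComputationSlice(
    srcOpInst, dstOpInst, /*dstLoopDepth=*/2, &sliceState);
if (!sliceLoop)
  return; // slicing failed; leave the IR untouched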
Optional<int64_t> getMemoryFootprintBytes(const ForInst &forInst,
Optional<int64_t> getMemoryFootprintBytes(ConstOpPointer<AffineForOp> forOp,
int memorySpace = -1);
} // end namespace mlir

View File

@ -25,8 +25,8 @@
namespace mlir {
class AffineApplyOp;
class AffineForOp;
class AffineMap;
class ForInst;
class FuncBuilder;
class Instruction;
class Location;
@ -71,7 +71,7 @@ shapeRatio(VectorType superVectorType, VectorType subVectorType);
/// loop information is extracted.
///
/// Prerequisites: `opInst` is a vectorizable load or store operation (i.e. at
/// most one invariant index along each ForInst of `loopToVectorDim`).
/// most one invariant index along each AffineForOp of `loopToVectorDim`).
///
/// Example 1:
/// The following MLIR snippet:
@ -122,9 +122,9 @@ shapeRatio(VectorType superVectorType, VectorType subVectorType);
/// Meaning that vector_transfer_read will be responsible for reading the slice
/// `%arg0[%c0, %c0]` into vector<128xf32> which needs a 1-D vector broadcast.
///
AffineMap
makePermutationMap(OperationInst *opInst,
const llvm::DenseMap<ForInst *, unsigned> &loopToVectorDim);
AffineMap makePermutationMap(
OperationInst *opInst,
const llvm::DenseMap<Instruction *, unsigned> &loopToVectorDim);
namespace matcher {

View File

@ -239,11 +239,6 @@ public:
/// current function.
Block *createBlock(Block *insertBefore = nullptr);
/// Returns a builder for the body of a 'for' instruction.
static FuncBuilder getForInstBodyBuilder(ForInst *forInst) {
return FuncBuilder(forInst->getBody(), forInst->getBody()->end());
}
/// Returns the current block of the builder.
Block *getBlock() const { return block; }
@ -277,15 +272,6 @@ public:
return cloneInst;
}
// Creates a for instruction. When step is not specified, it is set to 1.
ForInst *createFor(Location location, ArrayRef<Value *> lbOperands,
AffineMap lbMap, ArrayRef<Value *> ubOperands,
AffineMap ubMap, int64_t step = 1);
// Creates a for instruction with known (constant) lower and upper bounds.
// Default step is 1.
ForInst *createFor(Location loc, int64_t lb, int64_t ub, int64_t step = 1);
private:
Function *function;
Block *block = nullptr;

View File

@ -83,8 +83,6 @@ public:
"Must pass the derived type to this template!");
switch (s->getKind()) {
case Instruction::Kind::For:
return static_cast<SubClass *>(this)->visitForInst(cast<ForInst>(s));
case Instruction::Kind::OperationInst:
return static_cast<SubClass *>(this)->visitOperationInst(
cast<OperationInst>(s));
@ -101,7 +99,6 @@ public:
// When visiting a for inst, if inst, or an operation inst directly, these
// methods get called to indicate when transitioning into a new unit.
void visitForInst(ForInst *forInst) {}
void visitOperationInst(OperationInst *opInst) {}
};
@ -147,22 +144,11 @@ public:
void walkOpInstPostOrder(OperationInst *opInst) {
for (auto &blockList : opInst->getBlockLists())
for (auto &block : blockList)
static_cast<SubClass *>(this)->walk(block.begin(), block.end());
static_cast<SubClass *>(this)->walkPostOrder(block.begin(),
block.end());
static_cast<SubClass *>(this)->visitOperationInst(opInst);
}
void walkForInst(ForInst *forInst) {
static_cast<SubClass *>(this)->visitForInst(forInst);
auto *body = forInst->getBody();
static_cast<SubClass *>(this)->walk(body->begin(), body->end());
}
void walkForInstPostOrder(ForInst *forInst) {
auto *body = forInst->getBody();
static_cast<SubClass *>(this)->walkPostOrder(body->begin(), body->end());
static_cast<SubClass *>(this)->visitForInst(forInst);
}
// Function to walk a instruction.
RetTy walk(Instruction *s) {
static_assert(std::is_base_of<InstWalker, SubClass>::value,
@ -171,8 +157,6 @@ public:
static_cast<SubClass *>(this)->visitInstruction(s);
switch (s->getKind()) {
case Instruction::Kind::For:
return static_cast<SubClass *>(this)->walkForInst(cast<ForInst>(s));
case Instruction::Kind::OperationInst:
return static_cast<SubClass *>(this)->walkOpInst(cast<OperationInst>(s));
}
@ -185,9 +169,6 @@ public:
static_cast<SubClass *>(this)->visitInstruction(s);
switch (s->getKind()) {
case Instruction::Kind::For:
return static_cast<SubClass *>(this)->walkForInstPostOrder(
cast<ForInst>(s));
case Instruction::Kind::OperationInst:
return static_cast<SubClass *>(this)->walkOpInstPostOrder(
cast<OperationInst>(s));
@ -205,7 +186,6 @@ public:
// called. These are typically O(1) complexity and shouldn't be recursively
// processing their descendants in some way. When using RetTy, all of these
// need to be overridden.
void visitForInst(ForInst *forInst) {}
void visitOperationInst(OperationInst *opInst) {}
void visitInstruction(Instruction *inst) {}
};

View File

@ -32,7 +32,6 @@ namespace mlir {
class Block;
class BlockAndValueMapping;
class Location;
class ForInst;
class MLIRContext;
/// Terminator operations can have Block operands to represent successors.
@ -74,7 +73,6 @@ class Instruction : public IROperandOwner,
public:
enum class Kind {
OperationInst = (int)IROperandOwner::Kind::OperationInst,
For = (int)IROperandOwner::Kind::ForInst,
};
Kind getKind() const { return (Kind)IROperandOwner::getKind(); }

View File

@ -26,15 +26,11 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Instruction.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/OperationSupport.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/TrailingObjects.h"
namespace mlir {
class AffineBound;
class IntegerSet;
class AffineCondition;
class AttributeListStorage;
template <typename OpType> class ConstOpPointer;
template <typename OpType> class OpPointer;
@ -219,6 +215,13 @@ public:
return getOperandStorage().isResizable();
}
/// Replace the current operands of this operation with the ones provided in
/// 'operands'. If the operands list is not resizable, the size of 'operands'
/// must be less than or equal to the current number of operands.
void setOperands(ArrayRef<Value *> operands) {
getOperandStorage().setOperands(this, operands);
}
unsigned getNumOperands() const { return getOperandStorage().size(); }
Value *getOperand(unsigned idx) { return getInstOperand(idx).get(); }
@ -697,262 +700,6 @@ inline auto OperationInst::getResultTypes() const
return {result_type_begin(), result_type_end()};
}
/// For instruction represents an affine loop nest.
class ForInst final
: public Instruction,
private llvm::TrailingObjects<ForInst, detail::OperandStorage> {
public:
static ForInst *create(Location location, ArrayRef<Value *> lbOperands,
AffineMap lbMap, ArrayRef<Value *> ubOperands,
AffineMap ubMap, int64_t step);
/// Resolve base class ambiguity.
using Instruction::getFunction;
/// Operand iterators.
using operand_iterator = OperandIterator<ForInst, Value>;
using const_operand_iterator = OperandIterator<const ForInst, const Value>;
/// Operand iterator range.
using operand_range = llvm::iterator_range<operand_iterator>;
using const_operand_range = llvm::iterator_range<const_operand_iterator>;
/// Get the body of the ForInst.
Block *getBody() { return &body.front(); }
/// Get the body of the ForInst.
const Block *getBody() const { return &body.front(); }
//===--------------------------------------------------------------------===//
// Bounds and step
//===--------------------------------------------------------------------===//
/// Returns information about the lower bound as a single object.
const AffineBound getLowerBound() const;
/// Returns information about the upper bound as a single object.
const AffineBound getUpperBound() const;
/// Returns loop step.
int64_t getStep() const { return step; }
/// Returns affine map for the lower bound.
AffineMap getLowerBoundMap() const { return lbMap; }
/// Returns affine map for the upper bound. The upper bound is exclusive.
AffineMap getUpperBoundMap() const { return ubMap; }
/// Set lower bound.
void setLowerBound(ArrayRef<Value *> operands, AffineMap map);
/// Set upper bound.
void setUpperBound(ArrayRef<Value *> operands, AffineMap map);
/// Set the lower bound map without changing operands.
void setLowerBoundMap(AffineMap map);
/// Set the upper bound map without changing operands.
void setUpperBoundMap(AffineMap map);
/// Set loop step.
void setStep(int64_t step) {
assert(step > 0 && "step has to be a positive integer constant");
this->step = step;
}
/// Returns true if the lower bound is constant.
bool hasConstantLowerBound() const;
/// Returns true if the upper bound is constant.
bool hasConstantUpperBound() const;
/// Returns true if both bounds are constant.
bool hasConstantBounds() const {
return hasConstantLowerBound() && hasConstantUpperBound();
}
/// Returns the value of the constant lower bound.
/// Fails assertion if the bound is non-constant.
int64_t getConstantLowerBound() const;
/// Returns the value of the constant upper bound. The upper bound is
/// exclusive. Fails assertion if the bound is non-constant.
int64_t getConstantUpperBound() const;
/// Sets the lower bound to the given constant value.
void setConstantLowerBound(int64_t value);
/// Sets the upper bound to the given constant value.
void setConstantUpperBound(int64_t value);
/// Returns true if both the lower and upper bound have the same operand lists
/// (same operands in the same order).
bool matchingBoundOperandList() const;
/// Walk the operation instructions in the 'for' instruction in preorder,
/// calling the callback for each operation.
void walkOps(std::function<void(OperationInst *)> callback);
/// Walk the operation instructions in the 'for' instruction in postorder,
/// calling the callback for each operation.
void walkOpsPostOrder(std::function<void(OperationInst *)> callback);
//===--------------------------------------------------------------------===//
// Operands
//===--------------------------------------------------------------------===//
unsigned getNumOperands() const { return getOperandStorage().size(); }
Value *getOperand(unsigned idx) { return getInstOperand(idx).get(); }
const Value *getOperand(unsigned idx) const {
return getInstOperand(idx).get();
}
void setOperand(unsigned idx, Value *value) {
getInstOperand(idx).set(value);
}
operand_iterator operand_begin() { return operand_iterator(this, 0); }
operand_iterator operand_end() {
return operand_iterator(this, getNumOperands());
}
const_operand_iterator operand_begin() const {
return const_operand_iterator(this, 0);
}
const_operand_iterator operand_end() const {
return const_operand_iterator(this, getNumOperands());
}
ArrayRef<InstOperand> getInstOperands() const {
return getOperandStorage().getInstOperands();
}
MutableArrayRef<InstOperand> getInstOperands() {
return getOperandStorage().getInstOperands();
}
InstOperand &getInstOperand(unsigned idx) { return getInstOperands()[idx]; }
const InstOperand &getInstOperand(unsigned idx) const {
return getInstOperands()[idx];
}
// TODO: provide iterators for the lower and upper bound operands
// if the current access via getLowerBound(), getUpperBound() is too slow.
/// Returns operands for the lower bound map.
operand_range getLowerBoundOperands();
const_operand_range getLowerBoundOperands() const;
/// Returns operands for the upper bound map.
operand_range getUpperBoundOperands();
const_operand_range getUpperBoundOperands() const;
//===--------------------------------------------------------------------===//
// Other
//===--------------------------------------------------------------------===//
/// Return the context this operation is associated with.
MLIRContext *getContext() const {
return getInductionVar()->getType().getContext();
}
using Instruction::dump;
using Instruction::print;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const IROperandOwner *ptr) {
return ptr->getKind() == IROperandOwner::Kind::ForInst;
}
/// Returns the induction variable for this loop.
Value *getInductionVar();
const Value *getInductionVar() const {
return const_cast<ForInst *>(this)->getInductionVar();
}
void destroy();
private:
// The Block for the body. By construction, this list always contains exactly
// one block.
BlockList body;
// Affine map for the lower bound.
AffineMap lbMap;
// Affine map for the upper bound. The upper bound is exclusive.
AffineMap ubMap;
// Positive constant step. Since index is stored as an int64_t, we restrict
// step to the set of positive integers that int64_t can represent.
int64_t step;
explicit ForInst(Location location, AffineMap lbMap, AffineMap ubMap,
int64_t step);
~ForInst();
/// Returns the operand storage object.
detail::OperandStorage &getOperandStorage() {
return *getTrailingObjects<detail::OperandStorage>();
}
const detail::OperandStorage &getOperandStorage() const {
return *getTrailingObjects<detail::OperandStorage>();
}
// This stuff is used by the TrailingObjects template.
friend llvm::TrailingObjects<ForInst, detail::OperandStorage>;
};
/// Returns if the provided value is the induction variable of a ForInst.
bool isForInductionVar(const Value *val);
/// Returns the loop parent of an induction variable. If the provided value is
/// not an induction variable, then return nullptr.
ForInst *getForInductionVarOwner(Value *val);
const ForInst *getForInductionVarOwner(const Value *val);
/// Extracts the induction variables from a list of ForInsts and returns them.
SmallVector<Value *, 8> extractForInductionVars(ArrayRef<ForInst *> forInsts);
/// AffineBound represents a lower or upper bound in the for instruction.
/// This class does not own the underlying operands. Instead, it refers
/// to the operands stored in the ForInst. Its life span should not exceed
/// that of the for instruction it refers to.
class AffineBound {
public:
const ForInst *getForInst() const { return &inst; }
AffineMap getMap() const { return map; }
unsigned getNumOperands() const { return opEnd - opStart; }
const Value *getOperand(unsigned idx) const {
return inst.getOperand(opStart + idx);
}
const InstOperand &getInstOperand(unsigned idx) const {
return inst.getInstOperand(opStart + idx);
}
using operand_iterator = ForInst::operand_iterator;
using operand_range = ForInst::operand_range;
operand_iterator operand_begin() const {
// These are iterators over Value *. Not casting away const'ness would
// require the caller to use const Value *.
return operand_iterator(const_cast<ForInst *>(&inst), opStart);
}
operand_iterator operand_end() const {
return operand_iterator(const_cast<ForInst *>(&inst), opEnd);
}
/// Returns an iterator on the underlying Value's (Value *).
operand_range getOperands() const { return {operand_begin(), operand_end()}; }
ArrayRef<InstOperand> getInstOperands() const {
auto ops = inst.getInstOperands();
return ArrayRef<InstOperand>(ops.begin() + opStart, ops.begin() + opEnd);
}
private:
// 'for' instruction that contains this bound.
const ForInst &inst;
// Start and end positions of this affine bound operands in the list of
// the containing 'for' instruction operands.
unsigned opStart, opEnd;
// Affine map for this bound.
AffineMap map;
AffineBound(const ForInst &inst, unsigned opStart, unsigned opEnd,
AffineMap map)
: inst(inst), opStart(opStart), opEnd(opEnd), map(map) {}
friend class ForInst;
};
} // end namespace mlir
#endif // MLIR_IR_INSTRUCTIONS_H

View File

@ -68,6 +68,11 @@ public:
operator bool() const { return value.getInstruction(); }
bool operator==(OpPointer rhs) const {
return value.getInstruction() == rhs.value.getInstruction();
}
bool operator!=(OpPointer rhs) const { return !(*this == rhs); }
/// OpPointer can be implicitly converted to OpType*.
/// Return `nullptr` if there is no associated OperationInst*.
operator OpType *() {
@ -87,6 +92,9 @@ public:
private:
OpType value;
// Allow access to value to enable constructing an empty ConstOpPointer.
friend class ConstOpPointer<OpType>;
};
/// This pointer represents a notional "const OperationInst*" but where the
@ -96,6 +104,7 @@ class ConstOpPointer {
public:
explicit ConstOpPointer() : value(OperationInst::getNull<OpType>().value) {}
explicit ConstOpPointer(OpType value) : value(value) {}
ConstOpPointer(OpPointer<OpType> pointer) : value(pointer.value) {}
const OpType &operator*() const { return value; }
@ -104,6 +113,11 @@ public:
/// Return true if non-null.
operator bool() const { return value.getInstruction(); }
bool operator==(ConstOpPointer rhs) const {
return value.getInstruction() == rhs.value.getInstruction();
}
bool operator!=(ConstOpPointer rhs) const { return !(*this == rhs); }
/// ConstOpPointer can always be implicitly converted to const OpType*.
/// Return `nullptr` if there is no associated OperationInst*.
operator const OpType *() const {

View File

@ -90,7 +90,8 @@ public:
virtual void printGenericOp(const OperationInst *op) = 0;
/// Prints a block list.
virtual void printBlockList(const BlockList &blocks) = 0;
virtual void printBlockList(const BlockList &blocks,
bool printEntryBlockArgs = true) = 0;
private:
OpAsmPrinter(const OpAsmPrinter &) = delete;
@ -170,6 +171,9 @@ public:
/// This parses... a comma!
virtual bool parseComma() = 0;
/// This parses an equal(=) token!
virtual bool parseEqual() = 0;
/// Parse a type.
virtual bool parseType(Type &result) = 0;
@ -203,9 +207,9 @@ public:
}
/// Parse a keyword.
bool parseKeyword(const char *keyword) {
bool parseKeyword(const char *keyword, const Twine &msg = "") {
if (parseOptionalKeyword(keyword))
return emitError(getNameLoc(), "expected '" + Twine(keyword) + "'");
return emitError(getNameLoc(), "expected '" + Twine(keyword) + "'" + msg);
return false;
}
@ -315,6 +319,10 @@ public:
/// operation's block lists after the operation is created.
virtual bool parseBlockList() = 0;
/// Parses an argument for the entry block of the next block list to be
/// parsed.
virtual bool parseBlockListEntryBlockArgument(Type argType) = 0;
//===--------------------------------------------------------------------===//
// Methods for interacting with the parser
//===--------------------------------------------------------------------===//

View File

@ -75,15 +75,14 @@ private:
};
/// Subclasses of IROperandOwner can be the owner of an IROperand. In practice
/// this is the common base between Instruction and Instruction.
/// this is the common base between Instructions.
class IROperandOwner {
public:
enum class Kind {
OperationInst,
ForInst,
/// These enums define ranges used for classof implementations.
INST_LAST = ForInst,
INST_LAST = OperationInst,
};
Kind getKind() const { return locationAndKind.getInt(); }

View File

@ -27,11 +27,12 @@
#include "mlir/Support/LLVM.h"
namespace mlir {
class AffineMap;
class ForInst;
class AffineForOp;
template <typename T> class ConstOpPointer;
class Function;
class FuncBuilder;
template <typename T> class OpPointer;
// Values that can be used to signal success/failure. This can be implicitly
// converted to/from boolean values, with false representing success and true
@ -44,51 +45,54 @@ struct LLVM_NODISCARD UtilResult {
/// Unrolls this for instruction completely if the trip count is known to be
/// constant. Returns false otherwise.
bool loopUnrollFull(ForInst *forInst);
bool loopUnrollFull(OpPointer<AffineForOp> forOp);
/// Unrolls this for instruction by the specified unroll factor. Returns false
/// if the loop cannot be unrolled either due to restrictions or due to invalid
/// unroll factors.
bool loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor);
bool loopUnrollByFactor(OpPointer<AffineForOp> forOp, uint64_t unrollFactor);
/// Unrolls this loop by the specified unroll factor or its trip count,
/// whichever is lower.
bool loopUnrollUpToFactor(ForInst *forInst, uint64_t unrollFactor);
bool loopUnrollUpToFactor(OpPointer<AffineForOp> forOp, uint64_t unrollFactor);
/// Unrolls and jams this loop by the specified factor. Returns true if the loop
/// is successfully unroll-jammed.
bool loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor);
bool loopUnrollJamByFactor(OpPointer<AffineForOp> forOp,
uint64_t unrollJamFactor);
/// Unrolls and jams this loop by the specified factor or by the trip count (if
/// constant), whichever is lower.
bool loopUnrollJamUpToFactor(ForInst *forInst, uint64_t unrollJamFactor);
bool loopUnrollJamUpToFactor(OpPointer<AffineForOp> forOp,
uint64_t unrollJamFactor);
/// Promotes the loop body of a ForInst to its containing block if the ForInst
/// was known to have a single iteration. Returns false otherwise.
bool promoteIfSingleIteration(ForInst *forInst);
/// Promotes the loop body of an AffineForOp to its containing block if the
/// AffineForOp was known to have a single iteration. Returns false otherwise.
bool promoteIfSingleIteration(OpPointer<AffineForOp> forOp);
/// Promotes all single iteration ForInst's in the Function, i.e., moves
/// Promotes all single iteration AffineForOp's in the Function, i.e., moves
/// their body into the containing Block.
void promoteSingleIterationLoops(Function *f);
/// Returns the lower bound of the cleanup loop when unrolling a loop
/// with the specified unroll factor.
AffineMap getCleanupLoopLowerBound(const ForInst &forInst,
AffineMap getCleanupLoopLowerBound(ConstOpPointer<AffineForOp> forOp,
unsigned unrollFactor, FuncBuilder *builder);
/// Returns the upper bound of an unrolled loop when unrolling with
/// the specified trip count, stride, and unroll factor.
AffineMap getUnrolledLoopUpperBound(const ForInst &forInst,
AffineMap getUnrolledLoopUpperBound(ConstOpPointer<AffineForOp> forOp,
unsigned unrollFactor,
FuncBuilder *builder);
/// Skew the instructions in the body of a 'for' instruction with the specified
/// instruction-wise shifts. The shifts are with respect to the original
/// execution order, and are multiplied by the loop 'step' before being applied.
UtilResult instBodySkew(ForInst *forInst, ArrayRef<uint64_t> shifts,
UtilResult instBodySkew(OpPointer<AffineForOp> forOp, ArrayRef<uint64_t> shifts,
bool unrollPrologueEpilogue = false);
/// Tiles the specified band of perfectly nested loops creating tile-space loops
/// and intra-tile loops. A band is a contiguous set of loops.
UtilResult tileCodeGen(ArrayRef<ForInst *> band, ArrayRef<unsigned> tileSizes);
UtilResult tileCodeGen(MutableArrayRef<OpPointer<AffineForOp>> band,
ArrayRef<unsigned> tileSizes);
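A hedged sketch chaining these utilities on an OpPointer<AffineForOp> named forOp:

// Fold away trivially single-iteration loops before trying to unroll.
if (promoteIfSingleIteration(forOp))
  return;
if (!loopUnrollByFactor(forOp, /*unrollFactor=*/4))
  loopUnrollUpToFactor(forOp, /*unrollFactor=*/4); // trip count may be < 4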
} // end namespace mlir

View File

@ -27,7 +27,8 @@
namespace mlir {
class ForInst;
class AffineForOp;
template <typename T> class ConstOpPointer;
class FunctionPass;
class ModulePass;
@ -57,9 +58,10 @@ FunctionPass *createMaterializeVectorsPass();
/// factors supplied through other means. If -1 is passed as the unrollFactor
/// and no callback is provided, anything passed from the command-line (if at
/// all) or the default unroll factor is used (LoopUnroll:kDefaultUnrollFactor).
FunctionPass *createLoopUnrollPass(
int unrollFactor = -1, int unrollFull = -1,
const std::function<unsigned(const ForInst &)> &getUnrollFactor = nullptr);
FunctionPass *
createLoopUnrollPass(int unrollFactor = -1, int unrollFull = -1,
const std::function<unsigned(ConstOpPointer<AffineForOp>)>
&getUnrollFactor = nullptr);
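For example, a hedged sketch of supplying the per-loop unroll factor callback:

FunctionPass *pass = createLoopUnrollPass(
    /*unrollFactor=*/-1, /*unrollFull=*/-1,
    [](ConstOpPointer<AffineForOp> forOp) -> unsigned {
      // Hypothetical policy: unroll constant-bound loops by 4, others not at all.
      return forOp->hasConstantUpperBound() ? 4 : 1;
    });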
/// Creates a loop unroll jam pass to unroll jam by the specified factor. A
/// factor of -1 lets the pass use the default factor or the one on the command

View File

@ -32,7 +32,7 @@
namespace mlir {
class ForInst;
class AffineForOp;
class FuncBuilder;
class Location;
class Module;
@ -115,7 +115,7 @@ void createAffineComputationSlice(
/// Folds the lower and upper bounds of a 'for' inst to constants if possible.
/// Returns false if the folding happens for at least one bound, true otherwise.
bool constantFoldBounds(ForInst *forInst);
bool constantFoldBounds(OpPointer<AffineForOp> forInst);
/// Replaces (potentially nested) function attributes in the operation "op"
/// with those specified in "remappingTable".

View File

@ -17,7 +17,10 @@
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/OpImplementation.h"
using namespace mlir;
@ -27,7 +30,445 @@ using namespace mlir;
AffineOpsDialect::AffineOpsDialect(MLIRContext *context)
: Dialect(/*namePrefix=*/"", context) {
addOperations<AffineIfOp>();
addOperations<AffineForOp, AffineIfOp>();
}
//===----------------------------------------------------------------------===//
// AffineForOp
//===----------------------------------------------------------------------===//
void AffineForOp::build(Builder *builder, OperationState *result,
ArrayRef<Value *> lbOperands, AffineMap lbMap,
ArrayRef<Value *> ubOperands, AffineMap ubMap,
int64_t step) {
assert((!lbMap && lbOperands.empty()) ||
lbOperands.size() == lbMap.getNumInputs() &&
"lower bound operand count does not match the affine map");
assert((!ubMap && ubOperands.empty()) ||
ubOperands.size() == ubMap.getNumInputs() &&
"upper bound operand count does not match the affine map");
assert(step > 0 && "step has to be a positive integer constant");
// Add an attribute for the step.
result->addAttribute(getStepAttrName(),
builder->getIntegerAttr(builder->getIndexType(), step));
// Add the lower bound.
result->addAttribute(getLowerBoundAttrName(),
builder->getAffineMapAttr(lbMap));
result->addOperands(lbOperands);
// Add the upper bound.
result->addAttribute(getUpperBoundAttrName(),
builder->getAffineMapAttr(ubMap));
result->addOperands(ubOperands);
// Reserve a block list for the body.
result->reserveBlockLists(/*numReserved=*/1);
// Set the operands list as resizable so that we can freely modify the bounds.
result->setOperandListToResizable();
}
void AffineForOp::build(Builder *builder, OperationState *result, int64_t lb,
int64_t ub, int64_t step) {
auto lbMap = AffineMap::getConstantMap(lb, builder->getContext());
auto ubMap = AffineMap::getConstantMap(ub, builder->getContext());
return build(builder, result, {}, lbMap, {}, ubMap, step);
}
bool AffineForOp::verify() const {
const auto &bodyBlockList = getInstruction()->getBlockList(0);
// The body block list must contain a single basic block.
if (bodyBlockList.empty() ||
std::next(bodyBlockList.begin()) != bodyBlockList.end())
return emitOpError("expected body block list to have a single block");
// Check that the body defines a single block argument for the induction
// variable.
const auto *body = getBody();
if (body->getNumArguments() != 1 ||
!body->getArgument(0)->getType().isIndex())
return emitOpError("expected body to have a single index argument for the "
"induction variable");
// TODO: check that loop bounds are properly formed.
return false;
}
/// Parse a for operation loop bounds.
static bool parseBound(bool isLower, OperationState *result, OpAsmParser *p) {
// 'min' / 'max' prefixes are generally syntactic sugar, but are required if
// the map has multiple results.
bool failedToParsedMinMax = p->parseOptionalKeyword(isLower ? "max" : "min");
auto &builder = p->getBuilder();
auto boundAttrName = isLower ? AffineForOp::getLowerBoundAttrName()
: AffineForOp::getUpperBoundAttrName();
// Parse ssa-id as identity map.
SmallVector<OpAsmParser::OperandType, 1> boundOpInfos;
if (p->parseOperandList(boundOpInfos))
return true;
if (!boundOpInfos.empty()) {
// Check that only one operand was parsed.
if (boundOpInfos.size() > 1)
return p->emitError(p->getNameLoc(),
"expected only one loop bound operand");
// TODO: improve error message when SSA value is not an affine integer.
// Currently it is 'use of value ... expects different type than prior uses'
if (p->resolveOperand(boundOpInfos.front(), builder.getIndexType(),
result->operands))
return true;
// Create an identity map using symbol id. This representation is optimized
// for storage. Analysis passes may expand it into a multi-dimensional map
// if desired.
AffineMap map = builder.getSymbolIdentityMap();
result->addAttribute(boundAttrName, builder.getAffineMapAttr(map));
return false;
}
Attribute boundAttr;
if (p->parseAttribute(boundAttr, builder.getIndexType(), boundAttrName.data(),
result->attributes))
return true;
// Parse full form - affine map followed by dim and symbol list.
if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
unsigned currentNumOperands = result->operands.size();
unsigned numDims;
if (parseDimAndSymbolList(p, result->operands, numDims))
return true;
auto map = affineMapAttr.getValue();
if (map.getNumDims() != numDims)
return p->emitError(
p->getNameLoc(),
"dim operand count and integer set dim count must match");
unsigned numDimAndSymbolOperands =
result->operands.size() - currentNumOperands;
if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
return p->emitError(
p->getNameLoc(),
"symbol operand count and integer set symbol count must match");
// If the map has multiple results, make sure that we parsed the min/max
// prefix.
if (map.getNumResults() > 1 && failedToParsedMinMax) {
if (isLower) {
return p->emitError(p->getNameLoc(),
"lower loop bound affine map with multiple results "
"requires 'max' prefix");
}
return p->emitError(p->getNameLoc(),
"upper loop bound affine map with multiple results "
"requires 'min' prefix");
}
return false;
}
// Parse custom assembly form.
if (auto integerAttr = boundAttr.dyn_cast<IntegerAttr>()) {
result->attributes.pop_back();
result->addAttribute(
boundAttrName, builder.getAffineMapAttr(
builder.getConstantAffineMap(integerAttr.getInt())));
return false;
}
return p->emitError(
p->getNameLoc(),
"expected valid affine map representation for loop bounds");
}
bool AffineForOp::parse(OpAsmParser *parser, OperationState *result) {
auto &builder = parser->getBuilder();
// Parse the induction variable followed by '='.
if (parser->parseBlockListEntryBlockArgument(builder.getIndexType()) ||
parser->parseEqual())
return true;
// Parse loop bounds.
if (parseBound(/*isLower=*/true, result, parser) ||
parser->parseKeyword("to", " between bounds") ||
parseBound(/*isLower=*/false, result, parser))
return true;
// Parse the optional loop step; we default to 1 if one is not present.
if (parser->parseOptionalKeyword("step")) {
result->addAttribute(
getStepAttrName(),
builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
} else {
llvm::SMLoc stepLoc;
IntegerAttr stepAttr;
if (parser->getCurrentLocation(&stepLoc) ||
parser->parseAttribute(stepAttr, builder.getIndexType(),
getStepAttrName().data(), result->attributes))
return true;
if (stepAttr.getValue().getSExtValue() < 0)
return parser->emitError(
stepLoc,
"expected step to be representable as a positive signed integer");
}
// Parse the body block list.
result->reserveBlockLists(/*numReserved=*/1);
if (parser->parseBlockList())
return true;
// Set the operands list as resizable so that we can freely modify the bounds.
result->setOperandListToResizable();
return false;
}
static void printBound(AffineBound bound, const char *prefix, OpAsmPrinter *p) {
AffineMap map = bound.getMap();
// Check if this bound should be printed using custom assembly form.
// The decision to restrict printing custom assembly form to trivial cases
// comes from the desire to round-trip MLIR binary -> text -> binary in a
// lossless way.
// Therefore, custom assembly form parsing and printing is only supported for
// zero-operand constant maps and single symbol operand identity maps.
if (map.getNumResults() == 1) {
AffineExpr expr = map.getResult(0);
// Print constant bound.
if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
*p << constExpr.getValue();
return;
}
}
// Print bound that consists of a single SSA symbol if the map is over a
// single symbol.
if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
p->printOperand(bound.getOperand(0));
return;
}
}
} else {
// Map has multiple results. Print 'min' or 'max' prefix.
*p << prefix << ' ';
}
// Print the map and its operands.
p->printAffineMap(map);
printDimAndSymbolList(bound.operand_begin(), bound.operand_end(),
map.getNumDims(), p);
}
void AffineForOp::print(OpAsmPrinter *p) const {
*p << "for ";
p->printOperand(getBody()->getArgument(0));
*p << " = ";
printBound(getLowerBound(), "max", p);
*p << " to ";
printBound(getUpperBound(), "min", p);
if (getStep() != 1)
*p << " step " << getStep();
p->printBlockList(getInstruction()->getBlockList(0),
/*printEntryBlockArgs=*/false);
}
Block *AffineForOp::createBody() {
auto &bodyBlockList = getBlockList();
assert(bodyBlockList.empty() && "expected no existing body blocks");
// Create a new block for the body, and add an argument for the induction
// variable.
Block *body = new Block();
body->addArgument(IndexType::get(getInstruction()->getContext()));
bodyBlockList.push_back(body);
return body;
}
const AffineBound AffineForOp::getLowerBound() const {
auto lbMap = getLowerBoundMap();
return AffineBound(ConstOpPointer<AffineForOp>(*this), 0,
lbMap.getNumInputs(), lbMap);
}
const AffineBound AffineForOp::getUpperBound() const {
auto lbMap = getLowerBoundMap();
auto ubMap = getUpperBoundMap();
return AffineBound(ConstOpPointer<AffineForOp>(*this), lbMap.getNumInputs(),
getNumOperands(), ubMap);
}
void AffineForOp::setLowerBound(ArrayRef<Value *> lbOperands, AffineMap map) {
assert(lbOperands.size() == map.getNumInputs());
assert(map.getNumResults() >= 1 && "bound map has at least one result");
SmallVector<Value *, 4> newOperands(lbOperands.begin(), lbOperands.end());
auto ubOperands = getUpperBoundOperands();
newOperands.append(ubOperands.begin(), ubOperands.end());
getInstruction()->setOperands(newOperands);
setAttr(Identifier::get(getLowerBoundAttrName(), map.getContext()),
AffineMapAttr::get(map));
}
void AffineForOp::setUpperBound(ArrayRef<Value *> ubOperands, AffineMap map) {
assert(ubOperands.size() == map.getNumInputs());
assert(map.getNumResults() >= 1 && "bound map has at least one result");
SmallVector<Value *, 4> newOperands(getLowerBoundOperands());
newOperands.append(ubOperands.begin(), ubOperands.end());
getInstruction()->setOperands(newOperands);
setAttr(Identifier::get(getUpperBoundAttrName(), map.getContext()),
AffineMapAttr::get(map));
}
void AffineForOp::setLowerBoundMap(AffineMap map) {
auto lbMap = getLowerBoundMap();
assert(lbMap.getNumDims() == map.getNumDims() &&
lbMap.getNumSymbols() == map.getNumSymbols());
assert(map.getNumResults() >= 1 && "bound map has at least one result");
(void)lbMap;
setAttr(Identifier::get(getLowerBoundAttrName(), map.getContext()),
AffineMapAttr::get(map));
}
void AffineForOp::setUpperBoundMap(AffineMap map) {
auto ubMap = getUpperBoundMap();
assert(ubMap.getNumDims() == map.getNumDims() &&
ubMap.getNumSymbols() == map.getNumSymbols());
assert(map.getNumResults() >= 1 && "bound map has at least one result");
(void)ubMap;
setAttr(Identifier::get(getUpperBoundAttrName(), map.getContext()),
AffineMapAttr::get(map));
}
bool AffineForOp::hasConstantLowerBound() const {
return getLowerBoundMap().isSingleConstant();
}
bool AffineForOp::hasConstantUpperBound() const {
return getUpperBoundMap().isSingleConstant();
}
int64_t AffineForOp::getConstantLowerBound() const {
return getLowerBoundMap().getSingleConstantResult();
}
int64_t AffineForOp::getConstantUpperBound() const {
return getUpperBoundMap().getSingleConstantResult();
}
void AffineForOp::setConstantLowerBound(int64_t value) {
setLowerBound(
{}, AffineMap::getConstantMap(value, getInstruction()->getContext()));
}
void AffineForOp::setConstantUpperBound(int64_t value) {
setUpperBound(
{}, AffineMap::getConstantMap(value, getInstruction()->getContext()));
}
AffineForOp::operand_range AffineForOp::getLowerBoundOperands() {
return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
}
AffineForOp::const_operand_range AffineForOp::getLowerBoundOperands() const {
return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
}
AffineForOp::operand_range AffineForOp::getUpperBoundOperands() {
return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()};
}
AffineForOp::const_operand_range AffineForOp::getUpperBoundOperands() const {
return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()};
}
bool AffineForOp::matchingBoundOperandList() const {
auto lbMap = getLowerBoundMap();
auto ubMap = getUpperBoundMap();
if (lbMap.getNumDims() != ubMap.getNumDims() ||
lbMap.getNumSymbols() != ubMap.getNumSymbols())
return false;
unsigned numOperands = lbMap.getNumInputs();
for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
// Compare Value *'s.
if (getOperand(i) != getOperand(numOperands + i))
return false;
}
return true;
}
void AffineForOp::walkOps(std::function<void(OperationInst *)> callback) {
struct Walker : public InstWalker<Walker> {
std::function<void(OperationInst *)> const &callback;
Walker(std::function<void(OperationInst *)> const &callback)
: callback(callback) {}
void visitOperationInst(OperationInst *opInst) { callback(opInst); }
};
Walker w(callback);
w.walk(getInstruction());
}
void AffineForOp::walkOpsPostOrder(
std::function<void(OperationInst *)> callback) {
struct Walker : public InstWalker<Walker> {
std::function<void(OperationInst *)> const &callback;
Walker(std::function<void(OperationInst *)> const &callback)
: callback(callback) {}
void visitOperationInst(OperationInst *opInst) { callback(opInst); }
};
Walker v(callback);
v.walkPostOrder(getInstruction());
}
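A hedged example of the walkers above, counting loads in the loop body (assumes LoadOp from mlir/StandardOps):

unsigned numLoads = 0;
forOp->walkOps([&](OperationInst *opInst) {
  if (opInst->isa<LoadOp>())
    ++numLoads;
});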
/// Returns the induction variable for this loop.
Value *AffineForOp::getInductionVar() { return getBody()->getArgument(0); }
/// Returns true if the provided value is the induction variable of an AffineForOp.
bool mlir::isForInductionVar(const Value *val) {
return getForInductionVarOwner(val) != nullptr;
}
/// Returns the loop parent of an induction variable. If the provided value is
/// not an induction variable, then return nullptr.
OpPointer<AffineForOp> mlir::getForInductionVarOwner(Value *val) {
const BlockArgument *ivArg = dyn_cast<BlockArgument>(val);
if (!ivArg || !ivArg->getOwner())
return OpPointer<AffineForOp>();
auto *containingInst = ivArg->getOwner()->getParent()->getContainingInst();
if (!containingInst)
return OpPointer<AffineForOp>();
return cast<OperationInst>(containingInst)->dyn_cast<AffineForOp>();
}
ConstOpPointer<AffineForOp> mlir::getForInductionVarOwner(const Value *val) {
auto nonConstOwner = getForInductionVarOwner(const_cast<Value *>(val));
return ConstOpPointer<AffineForOp>(nonConstOwner);
}
/// Extracts the induction variables from a list of AffineForOps and returns
/// them.
SmallVector<Value *, 8> mlir::extractForInductionVars(
MutableArrayRef<OpPointer<AffineForOp>> forInsts) {
SmallVector<Value *, 8> results;
for (auto forInst : forInsts)
results.push_back(forInst->getInductionVar());
return results;
}
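// An illustrative sketch (hypothetical helper, using only the accessors
// defined above): given a Value that may be an induction variable, recover the
// owning loop and inspect its constant bounds.
static void inspectOwningLoop(Value *value) {
  if (auto forOp = getForInductionVarOwner(value)) {
    if (forOp->hasConstantLowerBound() && forOp->hasConstantUpperBound()) {
      int64_t loopSpan =
          forOp->getConstantUpperBound() - forOp->getConstantLowerBound();
      (void)loopSpan; // e.g. feed a cost model; elided in this sketch.
    }
  }
}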
//===----------------------------------------------------------------------===//

View File

@ -21,12 +21,14 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Instructions.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Support/STLExtras.h"
@ -519,7 +521,7 @@ void mlir::getReachableAffineApplyOps(
State &state = worklist.back();
auto *opInst = state.value->getDefiningInst();
// Note: getDefiningInst will return nullptr if the operand is not an
// OperationInst (i.e. ForInst), which is a terminator for the search.
// OperationInst (e.g. an AffineForOp induction variable), which is a
// terminator for the search.
if (opInst == nullptr || !opInst->isa<AffineApplyOp>()) {
worklist.pop_back();
continue;
@ -546,21 +548,21 @@ void mlir::getReachableAffineApplyOps(
}
// Builds a system of constraints with dimensional identifiers corresponding to
// the loop IVs of the forInsts appearing in that order. Any symbols found in
// the loop IVs of the forOps appearing in that order. Any symbols found in
// the bound operands are added as symbols in the system. Returns false for the
// yet unimplemented cases.
// TODO(andydavis,bondhugula) Handle non-unit steps through local variables or
// stride information in FlatAffineConstraints. (For eg., by using iv - lb %
// step = 0 and/or by introducing a method in FlatAffineConstraints
// setExprStride(ArrayRef<int64_t> expr, int64_t stride)
bool mlir::getIndexSet(ArrayRef<ForInst *> forInsts,
bool mlir::getIndexSet(MutableArrayRef<OpPointer<AffineForOp>> forOps,
FlatAffineConstraints *domain) {
auto indices = extractForInductionVars(forInsts);
auto indices = extractForInductionVars(forOps);
// Reset the domain, associating the Values in 'indices' with its dimensions.
domain->reset(forInsts.size(), /*numSymbols=*/0, /*numLocals=*/0, indices);
for (auto *forInst : forInsts) {
// Add constraints from forInst's bounds.
if (!domain->addForInstDomain(*forInst))
domain->reset(forOps.size(), /*numSymbols=*/0, /*numLocals=*/0, indices);
for (auto forOp : forOps) {
// Add constraints from forOp's bounds.
if (!domain->addAffineForOpDomain(forOp))
return false;
}
return true;
@ -576,7 +578,7 @@ static bool getInstIndexSet(const Instruction *inst,
FlatAffineConstraints *indexSet) {
// TODO(andydavis) Extend this to gather enclosing IfInsts and consider
// factoring it out into a utility function.
SmallVector<ForInst *, 4> loops;
SmallVector<OpPointer<AffineForOp>, 4> loops;
getLoopIVs(*inst, &loops);
return getIndexSet(loops, indexSet);
}
@ -998,9 +1000,9 @@ static const Block *getCommonBlock(const MemRefAccess &srcAccess,
return block;
}
auto *commonForValue = srcDomain.getIdValue(numCommonLoops - 1);
auto *forInst = getForInductionVarOwner(commonForValue);
assert(forInst && "commonForValue was not an induction variable");
return forInst->getBody();
auto forOp = getForInductionVarOwner(commonForValue);
assert(forOp && "commonForValue was not an induction variable");
return forOp->getBody();
}
// Returns true if the ancestor operation instruction of 'srcAccess' appears
@ -1195,7 +1197,7 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const {
// until operands of the AffineValueMap are loop IVs or symbols.
// *) Build iteration domain constraints for each access. Iteration domain
// constraints are pairs of inequality constraints representing the
// upper/lower loop bounds for each ForInst in the loop nest associated
// upper/lower loop bounds for each AffineForOp in the loop nest associated
// with each access.
// *) Build dimension and symbol position maps for each access, which map
// Values from access functions and iteration domains to their position

View File

@ -20,6 +20,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
@ -1247,22 +1248,23 @@ void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) {
numSymbols = newSymbolCount;
}
bool FlatAffineConstraints::addForInstDomain(const ForInst &forInst) {
bool FlatAffineConstraints::addAffineForOpDomain(
ConstOpPointer<AffineForOp> forOp) {
unsigned pos;
// Pre-condition for this method.
if (!findId(*forInst.getInductionVar(), &pos)) {
if (!findId(*forOp->getInductionVar(), &pos)) {
assert(0 && "Value not found");
return false;
}
if (forInst.getStep() != 1)
if (forOp->getStep() != 1)
LLVM_DEBUG(llvm::dbgs()
<< "Domain conservative: non-unit stride not handled\n");
// Adds a lower or upper bound when the bounds aren't constant.
auto addLowerOrUpperBound = [&](bool lower) -> bool {
auto operands = lower ? forInst.getLowerBoundOperands()
: forInst.getUpperBoundOperands();
auto operands =
lower ? forOp->getLowerBoundOperands() : forOp->getUpperBoundOperands();
for (const auto &operand : operands) {
unsigned loc;
if (!findId(*operand, &loc)) {
@ -1291,7 +1293,7 @@ bool FlatAffineConstraints::addForInstDomain(const ForInst &forInst) {
}
auto boundMap =
lower ? forInst.getLowerBoundMap() : forInst.getUpperBoundMap();
lower ? forOp->getLowerBoundMap() : forOp->getUpperBoundMap();
FlatAffineConstraints localVarCst;
std::vector<SmallVector<int64_t, 8>> flatExprs;
@ -1321,16 +1323,16 @@ bool FlatAffineConstraints::addForInstDomain(const ForInst &forInst) {
return true;
};
if (forInst.hasConstantLowerBound()) {
addConstantLowerBound(pos, forInst.getConstantLowerBound());
if (forOp->hasConstantLowerBound()) {
addConstantLowerBound(pos, forOp->getConstantLowerBound());
} else {
// Non-constant lower bound case.
if (!addLowerOrUpperBound(/*lower=*/true))
return false;
}
if (forInst.hasConstantUpperBound()) {
addConstantUpperBound(pos, forInst.getConstantUpperBound() - 1);
if (forOp->hasConstantUpperBound()) {
addConstantUpperBound(pos, forOp->getConstantUpperBound() - 1);
return true;
}
// Non-constant upper bound case.

View File

@ -43,27 +43,27 @@ using namespace mlir;
/// Returns the trip count of the loop as an affine expression if the latter is
/// expressible as an affine expression, and nullptr otherwise. The trip count
/// expression is simplified before returning.
AffineExpr mlir::getTripCountExpr(const ForInst &forInst) {
AffineExpr mlir::getTripCountExpr(ConstOpPointer<AffineForOp> forOp) {
// upper_bound - lower_bound
int64_t loopSpan;
int64_t step = forInst.getStep();
auto *context = forInst.getContext();
int64_t step = forOp->getStep();
auto *context = forOp->getInstruction()->getContext();
if (forInst.hasConstantBounds()) {
int64_t lb = forInst.getConstantLowerBound();
int64_t ub = forInst.getConstantUpperBound();
if (forOp->hasConstantBounds()) {
int64_t lb = forOp->getConstantLowerBound();
int64_t ub = forOp->getConstantUpperBound();
loopSpan = ub - lb;
} else {
auto lbMap = forInst.getLowerBoundMap();
auto ubMap = forInst.getUpperBoundMap();
auto lbMap = forOp->getLowerBoundMap();
auto ubMap = forOp->getUpperBoundMap();
// TODO(bondhugula): handle max/min of multiple expressions.
if (lbMap.getNumResults() != 1 || ubMap.getNumResults() != 1)
return nullptr;
// TODO(bondhugula): handle bounds with different operands.
// Bounds have different operands, unhandled for now.
if (!forInst.matchingBoundOperandList())
if (!forOp->matchingBoundOperandList())
return nullptr;
// ub_expr - lb_expr
@ -89,8 +89,9 @@ AffineExpr mlir::getTripCountExpr(const ForInst &forInst) {
/// Returns the trip count of the loop if it's a constant, None otherwise. This
/// method uses affine expression analysis (in turn using getTripCount) and is
/// able to determine constant trip count in non-trivial cases.
llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForInst &forInst) {
auto tripCountExpr = getTripCountExpr(forInst);
llvm::Optional<uint64_t>
mlir::getConstantTripCount(ConstOpPointer<AffineForOp> forOp) {
auto tripCountExpr = getTripCountExpr(forOp);
if (!tripCountExpr)
return None;
@ -104,8 +105,8 @@ llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForInst &forInst) {
/// Returns the greatest known integral divisor of the trip count. Affine
/// expression analysis is used (indirectly through getTripCount), and
/// this method is thus able to determine non-trivial divisors.
uint64_t mlir::getLargestDivisorOfTripCount(const ForInst &forInst) {
auto tripCountExpr = getTripCountExpr(forInst);
uint64_t mlir::getLargestDivisorOfTripCount(ConstOpPointer<AffineForOp> forOp) {
auto tripCountExpr = getTripCountExpr(forOp);
if (!tripCountExpr)
return 1;
@ -126,7 +127,7 @@ uint64_t mlir::getLargestDivisorOfTripCount(const ForInst &forInst) {
}
bool mlir::isAccessInvariant(const Value &iv, const Value &index) {
assert(isForInductionVar(&iv) && "iv must be a ForInst");
assert(isForInductionVar(&iv) && "iv must be an AffineForOp induction variable");
assert(index.getType().isa<IndexType>() && "index must be of IndexType");
SmallVector<OperationInst *, 4> affineApplyOps;
getReachableAffineApplyOps({const_cast<Value *>(&index)}, affineApplyOps);
@ -163,7 +164,7 @@ mlir::getInvariantAccesses(const Value &iv,
}
/// Given:
/// 1. an induction variable `iv` of type ForInst;
/// 1. an induction variable `iv` of an AffineForOp;
/// 2. a `memoryOp` of type const LoadOp& or const StoreOp&;
/// 3. the index of the `fastestVaryingDim` along which to check;
/// determines whether `memoryOp`[`fastestVaryingDim`] is a contiguous access
@ -231,17 +232,18 @@ static bool isVectorTransferReadOrWrite(const Instruction &inst) {
}
using VectorizableInstFun =
std::function<bool(const ForInst &, const OperationInst &)>;
std::function<bool(ConstOpPointer<AffineForOp>, const OperationInst &)>;
static bool isVectorizableLoopWithCond(const ForInst &loop,
static bool isVectorizableLoopWithCond(ConstOpPointer<AffineForOp> loop,
VectorizableInstFun isVectorizableInst) {
if (!matcher::isParallelLoop(loop) && !matcher::isReductionLoop(loop)) {
auto *forInst = const_cast<OperationInst *>(loop->getInstruction());
if (!matcher::isParallelLoop(*forInst) &&
!matcher::isReductionLoop(*forInst)) {
return false;
}
// No vectorization across conditionals for now.
auto conditionals = matcher::If();
auto *forInst = const_cast<ForInst *>(&loop);
SmallVector<NestedMatch, 8> conditionalsMatched;
conditionals.match(forInst, &conditionalsMatched);
if (!conditionalsMatched.empty()) {
@ -251,7 +253,8 @@ static bool isVectorizableLoopWithCond(const ForInst &loop,
// No vectorization across unknown regions.
auto regions = matcher::Op([](const Instruction &inst) -> bool {
auto &opInst = cast<OperationInst>(inst);
return opInst.getNumBlockLists() != 0 && !opInst.isa<AffineIfOp>();
return opInst.getNumBlockLists() != 0 &&
!(opInst.isa<AffineIfOp>() || opInst.isa<AffineForOp>());
});
SmallVector<NestedMatch, 8> regionsMatched;
regions.match(forInst, &regionsMatched);
@ -288,23 +291,25 @@ static bool isVectorizableLoopWithCond(const ForInst &loop,
}
bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim(
const ForInst &loop, unsigned fastestVaryingDim) {
VectorizableInstFun fun(
[fastestVaryingDim](const ForInst &loop, const OperationInst &op) {
auto load = op.dyn_cast<LoadOp>();
auto store = op.dyn_cast<StoreOp>();
return load ? isContiguousAccess(*loop.getInductionVar(), *load,
fastestVaryingDim)
: isContiguousAccess(*loop.getInductionVar(), *store,
fastestVaryingDim);
});
ConstOpPointer<AffineForOp> loop, unsigned fastestVaryingDim) {
VectorizableInstFun fun([fastestVaryingDim](ConstOpPointer<AffineForOp> loop,
const OperationInst &op) {
auto load = op.dyn_cast<LoadOp>();
auto store = op.dyn_cast<StoreOp>();
return load ? isContiguousAccess(*loop->getInductionVar(), *load,
fastestVaryingDim)
: isContiguousAccess(*loop->getInductionVar(), *store,
fastestVaryingDim);
});
return isVectorizableLoopWithCond(loop, fun);
}
bool mlir::isVectorizableLoop(const ForInst &loop) {
bool mlir::isVectorizableLoop(ConstOpPointer<AffineForOp> loop) {
VectorizableInstFun fun(
// TODO: implement me
[](const ForInst &loop, const OperationInst &op) { return true; });
[](ConstOpPointer<AffineForOp> loop, const OperationInst &op) {
return true;
});
return isVectorizableLoopWithCond(loop, fun);
}
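// An illustrative sketch (hypothetical helper, not part of this patch): a loop
// nest is only worth vectorizing if every loop in it passes the check above.
static bool allLoopsVectorizable(ArrayRef<OpPointer<AffineForOp>> loops) {
  for (auto loop : loops)
    if (!isVectorizableLoop(ConstOpPointer<AffineForOp>(loop)))
      return false;
  return true;
}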
@ -313,9 +318,9 @@ bool mlir::isVectorizableLoop(const ForInst &loop) {
/// 'def' and all its uses have the same shift factor.
// TODO(mlir-team): extend this to check for memory-based dependence
// violation when we have the support.
bool mlir::isInstwiseShiftValid(const ForInst &forInst,
bool mlir::isInstwiseShiftValid(ConstOpPointer<AffineForOp> forOp,
ArrayRef<uint64_t> shifts) {
auto *forBody = forInst.getBody();
auto *forBody = forOp->getBody();
assert(shifts.size() == forBody->getInstructions().size());
unsigned s = 0;
for (const auto &inst : *forBody) {
@ -325,7 +330,7 @@ bool mlir::isInstwiseShiftValid(const ForInst &forInst,
for (unsigned i = 0, e = opInst->getNumResults(); i < e; ++i) {
const Value *result = opInst->getResult(i);
for (const InstOperand &use : result->getUses()) {
// If an ancestor instruction doesn't lie in the block of forInst,
// If an ancestor instruction doesn't lie in the block of forOp,
// there is no shift to check. This is a naive way. If performance
// becomes an issue, a map can be used to store 'shifts' - to look up
// the shift for an instruction in constant time.

View File

@ -115,6 +115,10 @@ void NestedPattern::matchOne(Instruction *inst,
}
}
static bool isAffineForOp(const Instruction &inst) {
return cast<OperationInst>(inst).isa<AffineForOp>();
}
static bool isAffineIfOp(const Instruction &inst) {
return isa<OperationInst>(inst) &&
cast<OperationInst>(inst).isa<AffineIfOp>();
@ -147,28 +151,34 @@ NestedPattern If(FilterFunctionType filter, ArrayRef<NestedPattern> nested) {
}
NestedPattern For(NestedPattern child) {
return NestedPattern(Instruction::Kind::For, child, defaultFilterFunction);
return NestedPattern(Instruction::Kind::OperationInst, child, isAffineForOp);
}
NestedPattern For(FilterFunctionType filter, NestedPattern child) {
return NestedPattern(Instruction::Kind::For, child, filter);
return NestedPattern(Instruction::Kind::OperationInst, child,
[=](const Instruction &inst) {
return isAffineForOp(inst) && filter(inst);
});
}
NestedPattern For(ArrayRef<NestedPattern> nested) {
return NestedPattern(Instruction::Kind::For, nested, defaultFilterFunction);
return NestedPattern(Instruction::Kind::OperationInst, nested, isAffineForOp);
}
NestedPattern For(FilterFunctionType filter, ArrayRef<NestedPattern> nested) {
return NestedPattern(Instruction::Kind::For, nested, filter);
return NestedPattern(Instruction::Kind::OperationInst, nested,
[=](const Instruction &inst) {
return isAffineForOp(inst) && filter(inst);
});
}
// TODO(ntv): parallel annotation on loops.
bool isParallelLoop(const Instruction &inst) {
const auto *loop = cast<ForInst>(&inst);
return (void *)loop || true; // loop->isParallel();
auto loop = cast<OperationInst>(inst).cast<AffineForOp>();
return loop || true; // loop->isParallel();
};
// TODO(ntv): reduction annotation on loops.
bool isReductionLoop(const Instruction &inst) {
const auto *loop = cast<ForInst>(&inst);
return (void *)loop || true; // loop->isReduction();
auto loop = cast<OperationInst>(inst).cast<AffineForOp>();
return loop || true; // loop->isReduction();
};
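// A small usage sketch (hypothetical helper; 'inst' is any Instruction rooted
// at the region of interest): with For now matching AffineForOp operations, a
// two-deep loop nest around a load or store is matched as follows.
static SmallVector<NestedMatch, 8> matchLoadStoreNests(Instruction *inst) {
  auto pattern = For(For(Op(isLoadOrStore)));
  SmallVector<NestedMatch, 8> matches;
  pattern.match(inst, &matches);
  return matches;
}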
bool isLoadOrStore(const Instruction &inst) {

View File

@ -20,6 +20,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/VectorAnalysis.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Instructions.h"
@ -52,7 +53,16 @@ void mlir::getForwardSlice(Instruction *inst,
return;
}
if (auto *opInst = dyn_cast<OperationInst>(inst)) {
auto *opInst = cast<OperationInst>(inst);
if (auto forOp = opInst->dyn_cast<AffineForOp>()) {
for (auto &u : forOp->getInductionVar()->getUses()) {
auto *ownerInst = u.getOwner();
if (forwardSlice->count(ownerInst) == 0) {
getForwardSlice(ownerInst, forwardSlice, filter,
/*topLevel=*/false);
}
}
} else {
assert(opInst->getNumResults() <= 1 && "NYI: multiple results");
if (opInst->getNumResults() > 0) {
for (auto &u : opInst->getResult(0)->getUses()) {
@ -63,16 +73,6 @@ void mlir::getForwardSlice(Instruction *inst,
}
}
}
} else if (auto *forInst = dyn_cast<ForInst>(inst)) {
for (auto &u : forInst->getInductionVar()->getUses()) {
auto *ownerInst = u.getOwner();
if (forwardSlice->count(ownerInst) == 0) {
getForwardSlice(ownerInst, forwardSlice, filter,
/*topLevel=*/false);
}
}
} else {
assert(false && "NYI slicing case");
}
// At the top level we reverse to get back the actual topological order.

View File

@ -38,15 +38,17 @@ using namespace mlir;
/// Populates 'loops' with the loops surrounding 'inst', ordered from
/// the outermost 'for' instruction to the innermost one.
void mlir::getLoopIVs(const Instruction &inst,
SmallVectorImpl<ForInst *> *loops) {
SmallVectorImpl<OpPointer<AffineForOp>> *loops) {
auto *currInst = inst.getParentInst();
ForInst *currForInst;
OpPointer<AffineForOp> currAffineForOp;
// Traverse up the hierarchy, collecting all 'for' instructions while skipping
// over 'if' instructions.
while (currInst && ((currForInst = dyn_cast<ForInst>(currInst)) ||
cast<OperationInst>(currInst)->isa<AffineIfOp>())) {
if (currForInst)
loops->push_back(currForInst);
while (currInst &&
((currAffineForOp =
cast<OperationInst>(currInst)->dyn_cast<AffineForOp>()) ||
cast<OperationInst>(currInst)->isa<AffineIfOp>())) {
if (currAffineForOp)
loops->push_back(currAffineForOp);
currInst = currInst->getParentInst();
}
std::reverse(loops->begin(), loops->end());
@ -148,7 +150,7 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth,
if (rank == 0) {
// A rank 0 memref has a 0-d region.
SmallVector<ForInst *, 4> ivs;
SmallVector<OpPointer<AffineForOp>, 4> ivs;
getLoopIVs(*opInst, &ivs);
SmallVector<Value *, 8> regionSymbols = extractForInductionVars(ivs);
@ -174,12 +176,12 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth,
unsigned numSymbols = accessMap.getNumSymbols();
// Add inequalities for loop lower/upper bounds.
for (unsigned i = 0; i < numDims + numSymbols; ++i) {
if (auto *loop = getForInductionVarOwner(accessValueMap.getOperand(i))) {
if (auto loop = getForInductionVarOwner(accessValueMap.getOperand(i))) {
// Note that regionCst can now have more dimensions than accessMap if the
// bounds expressions involve outer loops or other symbols.
// TODO(bondhugula): rewrite this to use getInstIndexSet; this way
// conditionals will be handled when the latter supports it.
if (!regionCst->addForInstDomain(*loop))
if (!regionCst->addAffineForOpDomain(loop))
return false;
} else {
// Has to be a valid symbol.
@ -203,14 +205,14 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth,
// Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which
// this memref region is symbolic.
SmallVector<ForInst *, 4> outerIVs;
SmallVector<OpPointer<AffineForOp>, 4> outerIVs;
getLoopIVs(*opInst, &outerIVs);
assert(loopDepth <= outerIVs.size() && "invalid loop depth");
outerIVs.resize(loopDepth);
for (auto *operand : accessValueMap.getOperands()) {
ForInst *iv;
OpPointer<AffineForOp> iv;
if ((iv = getForInductionVarOwner(operand)) &&
std::find(outerIVs.begin(), outerIVs.end(), iv) == outerIVs.end()) {
llvm::is_contained(outerIVs, iv) == false) {
regionCst->projectOut(operand);
}
}
@ -357,8 +359,10 @@ static Instruction *getInstAtPosition(ArrayRef<unsigned> positions,
}
if (level == positions.size() - 1)
return &inst;
if (auto *childForInst = dyn_cast<ForInst>(&inst))
return getInstAtPosition(positions, level + 1, childForInst->getBody());
if (auto childAffineForOp =
cast<OperationInst>(inst).dyn_cast<AffineForOp>())
return getInstAtPosition(positions, level + 1,
childAffineForOp->getBody());
for (auto &blockList : cast<OperationInst>(&inst)->getBlockLists()) {
for (auto &b : blockList)
@ -385,12 +389,12 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess,
return false;
}
// Get loop nest surrounding src operation.
SmallVector<ForInst *, 4> srcLoopIVs;
SmallVector<OpPointer<AffineForOp>, 4> srcLoopIVs;
getLoopIVs(*srcAccess.opInst, &srcLoopIVs);
unsigned numSrcLoopIVs = srcLoopIVs.size();
// Get loop nest surrounding dst operation.
SmallVector<ForInst *, 4> dstLoopIVs;
SmallVector<OpPointer<AffineForOp>, 4> dstLoopIVs;
getLoopIVs(*dstAccess.opInst, &dstLoopIVs);
unsigned numDstLoopIVs = dstLoopIVs.size();
if (dstLoopDepth > numDstLoopIVs) {
@ -437,38 +441,41 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess,
// solution.
// TODO(andydavis) Remove dependence on 'srcLoopDepth' here. Instead project
// out loop IVs we don't care about and produce smaller slice.
ForInst *mlir::insertBackwardComputationSlice(
OpPointer<AffineForOp> mlir::insertBackwardComputationSlice(
OperationInst *srcOpInst, OperationInst *dstOpInst, unsigned dstLoopDepth,
ComputationSliceState *sliceState) {
// Get loop nest surrounding src operation.
SmallVector<ForInst *, 4> srcLoopIVs;
SmallVector<OpPointer<AffineForOp>, 4> srcLoopIVs;
getLoopIVs(*srcOpInst, &srcLoopIVs);
unsigned numSrcLoopIVs = srcLoopIVs.size();
// Get loop nest surrounding dst operation.
SmallVector<ForInst *, 4> dstLoopIVs;
SmallVector<OpPointer<AffineForOp>, 4> dstLoopIVs;
getLoopIVs(*dstOpInst, &dstLoopIVs);
unsigned dstLoopIVsSize = dstLoopIVs.size();
if (dstLoopDepth > dstLoopIVsSize) {
dstOpInst->emitError("invalid destination loop depth");
return nullptr;
return OpPointer<AffineForOp>();
}
// Find the inst block positions of 'srcOpInst' within 'srcLoopIVs'.
SmallVector<unsigned, 4> positions;
// TODO(andydavis): This code is incorrect since srcLoopIVs can be 0-d.
findInstPosition(srcOpInst, srcLoopIVs[0]->getBlock(), &positions);
findInstPosition(srcOpInst, srcLoopIVs[0]->getInstruction()->getBlock(),
&positions);
// Clone src loop nest and insert it at the beginning of the instruction block
// of the loop at 'dstLoopDepth' in 'dstLoopIVs'.
auto *dstForInst = dstLoopIVs[dstLoopDepth - 1];
FuncBuilder b(dstForInst->getBody(), dstForInst->getBody()->begin());
auto *sliceLoopNest = cast<ForInst>(b.clone(*srcLoopIVs[0]));
auto dstAffineForOp = dstLoopIVs[dstLoopDepth - 1];
FuncBuilder b(dstAffineForOp->getBody(), dstAffineForOp->getBody()->begin());
auto sliceLoopNest =
cast<OperationInst>(b.clone(*srcLoopIVs[0]->getInstruction()))
->cast<AffineForOp>();
Instruction *sliceInst =
getInstAtPosition(positions, /*level=*/0, sliceLoopNest->getBody());
// Get loop nest surrounding 'sliceInst'.
SmallVector<ForInst *, 4> sliceSurroundingLoops;
SmallVector<OpPointer<AffineForOp>, 4> sliceSurroundingLoops;
getLoopIVs(*sliceInst, &sliceSurroundingLoops);
// Sanity check.
@ -481,11 +488,11 @@ ForInst *mlir::insertBackwardComputationSlice(
// Update loop bounds for loops in 'sliceLoopNest'.
for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
auto *forInst = sliceSurroundingLoops[dstLoopDepth + i];
auto forOp = sliceSurroundingLoops[dstLoopDepth + i];
if (AffineMap lbMap = sliceState->lbs[i])
forInst->setLowerBound(sliceState->lbOperands[i], lbMap);
forOp->setLowerBound(sliceState->lbOperands[i], lbMap);
if (AffineMap ubMap = sliceState->ubs[i])
forInst->setUpperBound(sliceState->ubOperands[i], ubMap);
forOp->setUpperBound(sliceState->ubOperands[i], ubMap);
}
return sliceLoopNest;
}
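// An illustrative call-site sketch (hypothetical helper; 'sliceState' is
// assumed to have been filled in beforehand, e.g. by
// getBackwardComputationSliceState): failure is now signalled by a null
// OpPointer rather than a null ForInst pointer.
static bool trySliceInsertion(OperationInst *srcOpInst, OperationInst *dstOpInst,
                              unsigned dstLoopDepth,
                              ComputationSliceState *sliceState) {
  auto sliceLoopNest = insertBackwardComputationSlice(srcOpInst, dstOpInst,
                                                      dstLoopDepth, sliceState);
  if (!sliceLoopNest)
    return false;
  // The cloned nest can now be transformed through its OpPointer<AffineForOp>.
  return true;
}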
@ -520,7 +527,7 @@ unsigned mlir::getNestingDepth(const Instruction &stmt) {
const Instruction *currInst = &stmt;
unsigned depth = 0;
while ((currInst = currInst->getParentInst())) {
if (isa<ForInst>(currInst))
if (cast<OperationInst>(currInst)->isa<AffineForOp>())
depth++;
}
return depth;
@ -530,14 +537,14 @@ unsigned mlir::getNestingDepth(const Instruction &stmt) {
/// where each lists loops from outer-most to inner-most in loop nest.
unsigned mlir::getNumCommonSurroundingLoops(const Instruction &A,
const Instruction &B) {
SmallVector<ForInst *, 4> loopsA, loopsB;
SmallVector<OpPointer<AffineForOp>, 4> loopsA, loopsB;
getLoopIVs(A, &loopsA);
getLoopIVs(B, &loopsB);
unsigned minNumLoops = std::min(loopsA.size(), loopsB.size());
unsigned numCommonLoops = 0;
for (unsigned i = 0; i < minNumLoops; ++i) {
if (loopsA[i] != loopsB[i])
if (loopsA[i]->getInstruction() != loopsB[i]->getInstruction())
break;
++numCommonLoops;
}
@ -571,13 +578,14 @@ static Optional<int64_t> getRegionSize(const MemRefRegion &region) {
return getMemRefEltSizeInBytes(memRefType) * numElements.getValue();
}
Optional<int64_t> mlir::getMemoryFootprintBytes(const ForInst &forInst,
int memorySpace) {
Optional<int64_t>
mlir::getMemoryFootprintBytes(ConstOpPointer<AffineForOp> forOp,
int memorySpace) {
std::vector<std::unique_ptr<MemRefRegion>> regions;
// Walk this 'for' instruction to gather all memory regions.
bool error = false;
const_cast<ForInst *>(&forInst)->walkOps([&](OperationInst *opInst) {
const_cast<AffineForOp &>(*forOp).walkOps([&](OperationInst *opInst) {
if (!opInst->isa<LoadOp>() && !opInst->isa<StoreOp>()) {
// Neither a load nor a store op.
return;

View File

@ -16,10 +16,12 @@
// =============================================================================
#include "mlir/Analysis/VectorAnalysis.h"
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Instructions.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/SuperVectorOps/SuperVectorOps.h"
#include "mlir/Support/Functional.h"
@ -105,7 +107,7 @@ Optional<SmallVector<unsigned, 4>> mlir::shapeRatio(VectorType superVectorType,
static AffineMap makePermutationMap(
MLIRContext *context,
llvm::iterator_range<OperationInst::operand_iterator> indices,
const DenseMap<ForInst *, unsigned> &enclosingLoopToVectorDim) {
const DenseMap<Instruction *, unsigned> &enclosingLoopToVectorDim) {
using functional::makePtrDynCaster;
using functional::map;
auto unwrappedIndices = map(makePtrDynCaster<Value, Value>(), indices);
@ -113,8 +115,9 @@ static AffineMap makePermutationMap(
getAffineConstantExpr(0, context));
for (auto kvp : enclosingLoopToVectorDim) {
assert(kvp.second < perm.size());
auto invariants =
getInvariantAccesses(*kvp.first->getInductionVar(), unwrappedIndices);
auto invariants = getInvariantAccesses(
*cast<OperationInst>(kvp.first)->cast<AffineForOp>()->getInductionVar(),
unwrappedIndices);
unsigned numIndices = unwrappedIndices.size();
unsigned countInvariantIndices = 0;
for (unsigned dim = 0; dim < numIndices; ++dim) {
@ -139,30 +142,30 @@ static AffineMap makePermutationMap(
/// TODO(ntv): could also be implemented as a collect parents followed by a
/// filter and made available outside this file.
template <typename T>
static SetVector<T *> getParentsOfType(Instruction *inst) {
SetVector<T *> res;
static SetVector<OperationInst *> getParentsOfType(Instruction *inst) {
SetVector<OperationInst *> res;
auto *current = inst;
while (auto *parent = current->getParentInst()) {
auto *typedParent = dyn_cast<T>(parent);
if (typedParent) {
assert(res.count(typedParent) == 0 && "Already inserted");
res.insert(typedParent);
if (auto typedParent =
cast<OperationInst>(parent)->template dyn_cast<T>()) {
assert(res.count(cast<OperationInst>(parent)) == 0 && "Already inserted");
res.insert(cast<OperationInst>(parent));
}
current = parent;
}
return res;
}
/// Returns the enclosing ForInst, from closest to farthest.
static SetVector<ForInst *> getEnclosingforInsts(Instruction *inst) {
return getParentsOfType<ForInst>(inst);
/// Returns the enclosing AffineForOp operations, from closest to farthest.
static SetVector<OperationInst *> getEnclosingforOps(Instruction *inst) {
return getParentsOfType<AffineForOp>(inst);
}
AffineMap
mlir::makePermutationMap(OperationInst *opInst,
const DenseMap<ForInst *, unsigned> &loopToVectorDim) {
DenseMap<ForInst *, unsigned> enclosingLoopToVectorDim;
auto enclosingLoops = getEnclosingforInsts(opInst);
AffineMap mlir::makePermutationMap(
OperationInst *opInst,
const DenseMap<Instruction *, unsigned> &loopToVectorDim) {
DenseMap<Instruction *, unsigned> enclosingLoopToVectorDim;
auto enclosingLoops = getEnclosingforOps(opInst);
for (auto *forInst : enclosingLoops) {
auto it = loopToVectorDim.find(forInst);
if (it != loopToVectorDim.end()) {

View File

@ -72,7 +72,6 @@ public:
bool verify();
bool verifyBlock(const Block &block, bool isTopLevel);
bool verifyOperation(const OperationInst &op);
bool verifyForInst(const ForInst &forInst);
bool verifyDominance(const Block &block);
bool verifyInstDominance(const Instruction &inst);
@ -175,10 +174,6 @@ bool FuncVerifier::verifyBlock(const Block &block, bool isTopLevel) {
if (verifyOperation(cast<OperationInst>(inst)))
return true;
break;
case Instruction::Kind::For:
if (verifyForInst(cast<ForInst>(inst)))
return true;
break;
}
}
@ -240,11 +235,6 @@ bool FuncVerifier::verifyOperation(const OperationInst &op) {
return false;
}
bool FuncVerifier::verifyForInst(const ForInst &forInst) {
// TODO: check that loop bounds are properly formed.
return verifyBlock(*forInst.getBody(), /*isTopLevel=*/false);
}
bool FuncVerifier::verifyDominance(const Block &block) {
for (auto &inst : block) {
// Check that all operands on the instruction are ok.
@ -262,10 +252,6 @@ bool FuncVerifier::verifyDominance(const Block &block) {
return true;
break;
}
case Instruction::Kind::For:
if (verifyDominance(*cast<ForInst>(inst).getBody()))
return true;
break;
}
}
return false;

View File

@ -21,12 +21,14 @@
#include "llvm/Support/raw_ostream.h"
#include "mlir-c/Core.h"
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/EDSC/MLIREmitter.h"
#include "mlir/EDSC/Types.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Instructions.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Value.h"
#include "mlir/StandardOps/StandardOps.h"
@ -133,8 +135,8 @@ static void printDefininingStatement(llvm::raw_ostream &os, const Value &v) {
inst->print(os);
return;
}
if (auto *forInst = getForInductionVarOwner(&v)) {
forInst->print(os);
if (auto forInst = getForInductionVarOwner(&v)) {
forInst->getInstruction()->print(os);
} else {
os << "unknown_ssa_value";
}
@ -300,7 +302,9 @@ Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) {
exprs[1]->getDefiningInst()->cast<ConstantIndexOp>()->getValue();
auto step =
exprs[2]->getDefiningInst()->cast<ConstantIndexOp>()->getValue();
res = builder->createFor(location, lb, ub, step)->getInductionVar();
auto forOp = builder->create<AffineForOp>(location, lb, ub, step);
forOp->createBody();
res = forOp->getInductionVar();
}
}

View File

@ -130,21 +130,8 @@ private:
void recordTypeReference(Type ty) { usedTypes.insert(ty); }
// Return true if this map could be printed using the custom assembly form.
static bool hasCustomForm(AffineMap boundMap) {
if (boundMap.isSingleConstant())
return true;
// Check if the affine map is single dim id or single symbol identity -
// (i)->(i) or ()[s]->(i)
return boundMap.getNumInputs() == 1 && boundMap.getNumResults() == 1 &&
(boundMap.getResult(0).isa<AffineDimExpr>() ||
boundMap.getResult(0).isa<AffineSymbolExpr>());
}
// Visit functions.
void visitInstruction(const Instruction *inst);
void visitForInst(const ForInst *forInst);
void visitOperationInst(const OperationInst *opInst);
void visitType(Type type);
void visitAttribute(Attribute attr);
@ -196,16 +183,6 @@ void ModuleState::visitAttribute(Attribute attr) {
}
}
void ModuleState::visitForInst(const ForInst *forInst) {
AffineMap lbMap = forInst->getLowerBoundMap();
if (!hasCustomForm(lbMap))
recordAffineMapReference(lbMap);
AffineMap ubMap = forInst->getUpperBoundMap();
if (!hasCustomForm(ubMap))
recordAffineMapReference(ubMap);
}
void ModuleState::visitOperationInst(const OperationInst *op) {
// Visit all the types used in the operation.
for (auto *operand : op->getOperands())
@ -220,8 +197,6 @@ void ModuleState::visitOperationInst(const OperationInst *op) {
void ModuleState::visitInstruction(const Instruction *inst) {
switch (inst->getKind()) {
case Instruction::Kind::For:
return visitForInst(cast<ForInst>(inst));
case Instruction::Kind::OperationInst:
return visitOperationInst(cast<OperationInst>(inst));
}
@ -1069,7 +1044,6 @@ public:
// Methods to print instructions.
void print(const Instruction *inst);
void print(const OperationInst *inst);
void print(const ForInst *inst);
void print(const Block *block, bool printBlockArgs = true);
void printOperation(const OperationInst *op);
@ -1117,10 +1091,8 @@ public:
unsigned index) override;
/// Print a block list.
void printBlockList(const BlockList &blocks) override {
printBlockList(blocks, /*printEntryBlockArgs=*/true);
}
void printBlockList(const BlockList &blocks, bool printEntryBlockArgs) {
void printBlockList(const BlockList &blocks,
bool printEntryBlockArgs) override {
os << " {\n";
if (!blocks.empty()) {
auto *entryBlock = &blocks.front();
@ -1132,10 +1104,6 @@ public:
os.indent(currentIndent) << "}";
}
// Print if and loop bounds.
void printDimAndSymbolList(ArrayRef<InstOperand> ops, unsigned numDims);
void printBound(AffineBound bound, const char *prefix);
// Number of spaces used for indenting nested instructions.
const static unsigned indentWidth = 2;
@ -1205,10 +1173,6 @@ void FunctionPrinter::numberValuesInBlock(const Block &block) {
numberValuesInBlock(block);
break;
}
case Instruction::Kind::For:
// Recursively number the stuff in the body.
numberValuesInBlock(*cast<ForInst>(&inst)->getBody());
break;
}
}
}
@ -1404,8 +1368,6 @@ void FunctionPrinter::print(const Instruction *inst) {
switch (inst->getKind()) {
case Instruction::Kind::OperationInst:
return print(cast<OperationInst>(inst));
case Instruction::Kind::For:
return print(cast<ForInst>(inst));
}
}
@ -1415,24 +1377,6 @@ void FunctionPrinter::print(const OperationInst *inst) {
printTrailingLocation(inst->getLoc());
}
void FunctionPrinter::print(const ForInst *inst) {
os.indent(currentIndent) << "for ";
printOperand(inst->getInductionVar());
os << " = ";
printBound(inst->getLowerBound(), "max");
os << " to ";
printBound(inst->getUpperBound(), "min");
if (inst->getStep() != 1)
os << " step " << inst->getStep();
printTrailingLocation(inst->getLoc());
os << " {\n";
print(inst->getBody(), /*printBlockArgs=*/false);
os.indent(currentIndent) << "}";
}
void FunctionPrinter::printValueID(const Value *value,
bool printResultNo) const {
int resultNo = -1;
@ -1560,62 +1504,6 @@ void FunctionPrinter::printSuccessorAndUseList(const OperationInst *term,
os << ')';
}
void FunctionPrinter::printDimAndSymbolList(ArrayRef<InstOperand> ops,
unsigned numDims) {
auto printComma = [&]() { os << ", "; };
os << '(';
interleave(
ops.begin(), ops.begin() + numDims,
[&](const InstOperand &v) { printOperand(v.get()); }, printComma);
os << ')';
if (numDims < ops.size()) {
os << '[';
interleave(
ops.begin() + numDims, ops.end(),
[&](const InstOperand &v) { printOperand(v.get()); }, printComma);
os << ']';
}
}
void FunctionPrinter::printBound(AffineBound bound, const char *prefix) {
AffineMap map = bound.getMap();
// Check if this bound should be printed using custom assembly form.
// The decision to restrict printing custom assembly form to trivial cases
// comes from the will to roundtrip MLIR binary -> text -> binary in a
// lossless way.
// Therefore, custom assembly form parsing and printing is only supported for
// zero-operand constant maps and single symbol operand identity maps.
if (map.getNumResults() == 1) {
AffineExpr expr = map.getResult(0);
// Print constant bound.
if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
os << constExpr.getValue();
return;
}
}
// Print bound that consists of a single SSA symbol if the map is over a
// single symbol.
if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
printOperand(bound.getOperand(0));
return;
}
}
} else {
// Map has multiple results. Print 'min' or 'max' prefix.
os << prefix << ' ';
}
// Print the map and its operands.
printAffineMapReference(map);
printDimAndSymbolList(bound.getInstOperands(), map.getNumDims());
}
// Prints function with initialized module state.
void ModulePrinter::print(const Function *fn) {
FunctionPrinter(fn, *this).print();

View File

@ -312,19 +312,3 @@ OperationInst *FuncBuilder::createOperation(const OperationState &state) {
block->getInstructions().insert(insertPoint, op);
return op;
}
ForInst *FuncBuilder::createFor(Location location, ArrayRef<Value *> lbOperands,
AffineMap lbMap, ArrayRef<Value *> ubOperands,
AffineMap ubMap, int64_t step) {
auto *inst =
ForInst::create(location, lbOperands, lbMap, ubOperands, ubMap, step);
block->getInstructions().insert(insertPoint, inst);
return inst;
}
ForInst *FuncBuilder::createFor(Location location, int64_t lb, int64_t ub,
int64_t step) {
auto lbMap = AffineMap::getConstantMap(lb, context);
auto ubMap = AffineMap::getConstantMap(ub, context);
return createFor(location, {}, lbMap, {}, ubMap, step);
}
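// With these helpers removed, loop construction goes through the generic
// builder instead, mirroring the EDSC emitter change above. A minimal sketch,
// assuming a FuncBuilder 'b', a Location 'loc', and the AffineOps dialect
// headers visible at the call site:
//
//   auto forOp = b.create<AffineForOp>(loc, /*lb=*/0, /*ub=*/100, /*step=*/1);
//   forOp->createBody();
//   Value *iv = forOp->getInductionVar();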

View File

@ -143,9 +143,6 @@ void Instruction::destroy() {
case Kind::OperationInst:
cast<OperationInst>(this)->destroy();
break;
case Kind::For:
cast<ForInst>(this)->destroy();
break;
}
}
@ -209,8 +206,6 @@ unsigned Instruction::getNumOperands() const {
switch (getKind()) {
case Kind::OperationInst:
return cast<OperationInst>(this)->getNumOperands();
case Kind::For:
return cast<ForInst>(this)->getNumOperands();
}
}
@ -218,8 +213,6 @@ MutableArrayRef<InstOperand> Instruction::getInstOperands() {
switch (getKind()) {
case Kind::OperationInst:
return cast<OperationInst>(this)->getInstOperands();
case Kind::For:
return cast<ForInst>(this)->getInstOperands();
}
}
@ -349,10 +342,6 @@ void Instruction::dropAllReferences() {
op.drop();
switch (getKind()) {
case Kind::For:
// Make sure to drop references held by instructions within the body.
cast<ForInst>(this)->getBody()->dropAllReferences();
break;
case Kind::OperationInst: {
auto *opInst = cast<OperationInst>(this);
if (isTerminator())
@ -655,217 +644,6 @@ bool OperationInst::emitOpError(const Twine &message) const {
return emitError(Twine('\'') + getName().getStringRef() + "' op " + message);
}
//===----------------------------------------------------------------------===//
// ForInst
//===----------------------------------------------------------------------===//
ForInst *ForInst::create(Location location, ArrayRef<Value *> lbOperands,
AffineMap lbMap, ArrayRef<Value *> ubOperands,
AffineMap ubMap, int64_t step) {
assert((!lbMap && lbOperands.empty()) ||
lbOperands.size() == lbMap.getNumInputs() &&
"lower bound operand count does not match the affine map");
assert((!ubMap && ubOperands.empty()) ||
ubOperands.size() == ubMap.getNumInputs() &&
"upper bound operand count does not match the affine map");
assert(step > 0 && "step has to be a positive integer constant");
// Compute the byte size for the instruction and the operand storage.
unsigned numOperands = lbOperands.size() + ubOperands.size();
auto byteSize = totalSizeToAlloc<detail::OperandStorage>(
/*detail::OperandStorage*/ 1);
byteSize += llvm::alignTo(detail::OperandStorage::additionalAllocSize(
numOperands, /*resizable=*/true),
alignof(ForInst));
void *rawMem = malloc(byteSize);
// Initialize the OperationInst part of the instruction.
ForInst *inst = ::new (rawMem) ForInst(location, lbMap, ubMap, step);
new (&inst->getOperandStorage())
detail::OperandStorage(numOperands, /*resizable=*/true);
auto operands = inst->getInstOperands();
unsigned i = 0;
for (unsigned e = lbOperands.size(); i != e; ++i)
new (&operands[i]) InstOperand(inst, lbOperands[i]);
for (unsigned j = 0, e = ubOperands.size(); j != e; ++i, ++j)
new (&operands[i]) InstOperand(inst, ubOperands[j]);
return inst;
}
ForInst::ForInst(Location location, AffineMap lbMap, AffineMap ubMap,
int64_t step)
: Instruction(Instruction::Kind::For, location), body(this), lbMap(lbMap),
ubMap(ubMap), step(step) {
// The body of a for inst always has one block.
auto *bodyEntry = new Block();
body.push_back(bodyEntry);
// Add an argument to the block for the induction variable.
bodyEntry->addArgument(Type::getIndex(lbMap.getResult(0).getContext()));
}
ForInst::~ForInst() { getOperandStorage().~OperandStorage(); }
const AffineBound ForInst::getLowerBound() const {
return AffineBound(*this, 0, lbMap.getNumInputs(), lbMap);
}
const AffineBound ForInst::getUpperBound() const {
return AffineBound(*this, lbMap.getNumInputs(), getNumOperands(), ubMap);
}
void ForInst::setLowerBound(ArrayRef<Value *> lbOperands, AffineMap map) {
assert(lbOperands.size() == map.getNumInputs());
assert(map.getNumResults() >= 1 && "bound map has at least one result");
SmallVector<Value *, 4> newOperands(lbOperands.begin(), lbOperands.end());
auto ubOperands = getUpperBoundOperands();
newOperands.append(ubOperands.begin(), ubOperands.end());
getOperandStorage().setOperands(this, newOperands);
this->lbMap = map;
}
void ForInst::setUpperBound(ArrayRef<Value *> ubOperands, AffineMap map) {
assert(ubOperands.size() == map.getNumInputs());
assert(map.getNumResults() >= 1 && "bound map has at least one result");
SmallVector<Value *, 4> newOperands(getLowerBoundOperands());
newOperands.append(ubOperands.begin(), ubOperands.end());
getOperandStorage().setOperands(this, newOperands);
this->ubMap = map;
}
void ForInst::setLowerBoundMap(AffineMap map) {
assert(lbMap.getNumDims() == map.getNumDims() &&
lbMap.getNumSymbols() == map.getNumSymbols());
assert(map.getNumResults() >= 1 && "bound map has at least one result");
this->lbMap = map;
}
void ForInst::setUpperBoundMap(AffineMap map) {
assert(ubMap.getNumDims() == map.getNumDims() &&
ubMap.getNumSymbols() == map.getNumSymbols());
assert(map.getNumResults() >= 1 && "bound map has at least one result");
this->ubMap = map;
}
bool ForInst::hasConstantLowerBound() const { return lbMap.isSingleConstant(); }
bool ForInst::hasConstantUpperBound() const { return ubMap.isSingleConstant(); }
int64_t ForInst::getConstantLowerBound() const {
return lbMap.getSingleConstantResult();
}
int64_t ForInst::getConstantUpperBound() const {
return ubMap.getSingleConstantResult();
}
void ForInst::setConstantLowerBound(int64_t value) {
setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
}
void ForInst::setConstantUpperBound(int64_t value) {
setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
}
ForInst::operand_range ForInst::getLowerBoundOperands() {
return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
}
ForInst::const_operand_range ForInst::getLowerBoundOperands() const {
return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
}
ForInst::operand_range ForInst::getUpperBoundOperands() {
return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()};
}
ForInst::const_operand_range ForInst::getUpperBoundOperands() const {
return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()};
}
bool ForInst::matchingBoundOperandList() const {
if (lbMap.getNumDims() != ubMap.getNumDims() ||
lbMap.getNumSymbols() != ubMap.getNumSymbols())
return false;
unsigned numOperands = lbMap.getNumInputs();
for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
// Compare Value *'s.
if (getOperand(i) != getOperand(numOperands + i))
return false;
}
return true;
}
void ForInst::walkOps(std::function<void(OperationInst *)> callback) {
struct Walker : public InstWalker<Walker> {
std::function<void(OperationInst *)> const &callback;
Walker(std::function<void(OperationInst *)> const &callback)
: callback(callback) {}
void visitOperationInst(OperationInst *opInst) { callback(opInst); }
};
Walker w(callback);
w.walk(this);
}
void ForInst::walkOpsPostOrder(std::function<void(OperationInst *)> callback) {
struct Walker : public InstWalker<Walker> {
std::function<void(OperationInst *)> const &callback;
Walker(std::function<void(OperationInst *)> const &callback)
: callback(callback) {}
void visitOperationInst(OperationInst *opInst) { callback(opInst); }
};
Walker v(callback);
v.walkPostOrder(this);
}
/// Returns the induction variable for this loop.
Value *ForInst::getInductionVar() { return getBody()->getArgument(0); }
void ForInst::destroy() {
this->~ForInst();
free(this);
}
/// Returns if the provided value is the induction variable of a ForInst.
bool mlir::isForInductionVar(const Value *val) {
return getForInductionVarOwner(val) != nullptr;
}
/// Returns the loop parent of an induction variable. If the provided value is
/// not an induction variable, then return nullptr.
ForInst *mlir::getForInductionVarOwner(Value *val) {
const BlockArgument *ivArg = dyn_cast<BlockArgument>(val);
if (!ivArg || !ivArg->getOwner())
return nullptr;
return dyn_cast_or_null<ForInst>(
ivArg->getOwner()->getParent()->getContainingInst());
}
const ForInst *mlir::getForInductionVarOwner(const Value *val) {
return getForInductionVarOwner(const_cast<Value *>(val));
}
/// Extracts the induction variables from a list of ForInsts and returns them.
SmallVector<Value *, 8>
mlir::extractForInductionVars(ArrayRef<ForInst *> forInsts) {
SmallVector<Value *, 8> results;
for (auto *forInst : forInsts)
results.push_back(forInst->getInductionVar());
return results;
}
//===----------------------------------------------------------------------===//
// Instruction Cloning
//===----------------------------------------------------------------------===//
@ -879,84 +657,59 @@ Instruction *Instruction::clone(BlockAndValueMapping &mapper,
MLIRContext *context) const {
SmallVector<Value *, 8> operands;
SmallVector<Block *, 2> successors;
if (auto *opInst = dyn_cast<OperationInst>(this)) {
operands.reserve(getNumOperands() + opInst->getNumSuccessors());
if (!opInst->isTerminator()) {
// Non-terminators just add all the operands.
for (auto *opValue : getOperands())
auto *opInst = cast<OperationInst>(this);
operands.reserve(getNumOperands() + opInst->getNumSuccessors());
if (!opInst->isTerminator()) {
// Non-terminators just add all the operands.
for (auto *opValue : getOperands())
operands.push_back(mapper.lookupOrDefault(const_cast<Value *>(opValue)));
} else {
// We add the operands separated by nullptr's for each successor.
unsigned firstSuccOperand = opInst->getNumSuccessors()
? opInst->getSuccessorOperandIndex(0)
: opInst->getNumOperands();
auto InstOperands = opInst->getInstOperands();
unsigned i = 0;
for (; i != firstSuccOperand; ++i)
operands.push_back(
mapper.lookupOrDefault(const_cast<Value *>(InstOperands[i].get())));
successors.reserve(opInst->getNumSuccessors());
for (unsigned succ = 0, e = opInst->getNumSuccessors(); succ != e; ++succ) {
successors.push_back(mapper.lookupOrDefault(
const_cast<Block *>(opInst->getSuccessor(succ))));
// Add sentinel to delineate successor operands.
operands.push_back(nullptr);
// Remap the successors operands.
for (auto *operand : opInst->getSuccessorOperands(succ))
operands.push_back(
mapper.lookupOrDefault(const_cast<Value *>(opValue)));
} else {
// We add the operands separated by nullptr's for each successor.
unsigned firstSuccOperand = opInst->getNumSuccessors()
? opInst->getSuccessorOperandIndex(0)
: opInst->getNumOperands();
auto InstOperands = opInst->getInstOperands();
unsigned i = 0;
for (; i != firstSuccOperand; ++i)
operands.push_back(
mapper.lookupOrDefault(const_cast<Value *>(InstOperands[i].get())));
successors.reserve(opInst->getNumSuccessors());
for (unsigned succ = 0, e = opInst->getNumSuccessors(); succ != e;
++succ) {
successors.push_back(mapper.lookupOrDefault(
const_cast<Block *>(opInst->getSuccessor(succ))));
// Add sentinel to delineate successor operands.
operands.push_back(nullptr);
// Remap the successors operands.
for (auto *operand : opInst->getSuccessorOperands(succ))
operands.push_back(
mapper.lookupOrDefault(const_cast<Value *>(operand)));
}
mapper.lookupOrDefault(const_cast<Value *>(operand)));
}
SmallVector<Type, 8> resultTypes;
resultTypes.reserve(opInst->getNumResults());
for (auto *result : opInst->getResults())
resultTypes.push_back(result->getType());
unsigned numBlockLists = opInst->getNumBlockLists();
auto *newOp = OperationInst::create(
getLoc(), opInst->getName(), operands, resultTypes, opInst->getAttrs(),
successors, numBlockLists, opInst->hasResizableOperandsList(), context);
// Clone the block lists.
for (unsigned i = 0; i != numBlockLists; ++i)
opInst->getBlockList(i).cloneInto(&newOp->getBlockList(i), mapper,
context);
// Remember the mapping of any results.
for (unsigned i = 0, e = opInst->getNumResults(); i != e; ++i)
mapper.map(opInst->getResult(i), newOp->getResult(i));
return newOp;
}
operands.reserve(getNumOperands());
for (auto *opValue : getOperands())
operands.push_back(mapper.lookupOrDefault(const_cast<Value *>(opValue)));
SmallVector<Type, 8> resultTypes;
resultTypes.reserve(opInst->getNumResults());
for (auto *result : opInst->getResults())
resultTypes.push_back(result->getType());
// Otherwise, this must be a ForInst.
auto *forInst = cast<ForInst>(this);
auto lbMap = forInst->getLowerBoundMap();
auto ubMap = forInst->getUpperBoundMap();
unsigned numBlockLists = opInst->getNumBlockLists();
auto *newOp = OperationInst::create(
getLoc(), opInst->getName(), operands, resultTypes, opInst->getAttrs(),
successors, numBlockLists, opInst->hasResizableOperandsList(), context);
auto *newFor = ForInst::create(
getLoc(), ArrayRef<Value *>(operands).take_front(lbMap.getNumInputs()),
lbMap, ArrayRef<Value *>(operands).take_back(ubMap.getNumInputs()), ubMap,
forInst->getStep());
// Clone the block lists.
for (unsigned i = 0; i != numBlockLists; ++i)
opInst->getBlockList(i).cloneInto(&newOp->getBlockList(i), mapper, context);
// Remember the induction variable mapping.
mapper.map(forInst->getInductionVar(), newFor->getInductionVar());
// Recursively clone the body of the for loop.
for (auto &subInst : *forInst->getBody())
newFor->getBody()->push_back(subInst.clone(mapper, context));
return newFor;
// Remember the mapping of any results.
for (unsigned i = 0, e = opInst->getNumResults(); i != e; ++i)
mapper.map(opInst->getResult(i), newOp->getResult(i));
return newOp;
}
Instruction *Instruction::clone(MLIRContext *context) const {

View File

@ -64,8 +64,6 @@ MLIRContext *IROperandOwner::getContext() const {
switch (getKind()) {
case Kind::OperationInst:
return cast<OperationInst>(this)->getContext();
case Kind::ForInst:
return cast<ForInst>(this)->getContext();
}
}

View File

@ -2128,23 +2128,6 @@ public:
parseSuccessors(SmallVectorImpl<Block *> &destinations,
SmallVectorImpl<SmallVector<Value *, 4>> &operands);
ParseResult
parseOptionalBlockArgList(SmallVectorImpl<BlockArgument *> &results,
Block *owner);
ParseResult parseOperationBlockList(SmallVectorImpl<Block *> &results);
ParseResult parseBlockListBody(SmallVectorImpl<Block *> &results);
ParseResult parseBlock(Block *&block);
ParseResult parseBlockBody(Block *block);
/// Cleans up the memory for allocated blocks when a parser error occurs.
void cleanupInvalidBlocks(ArrayRef<Block *> invalidBlocks) {
// Add the referenced blocks to the function so that they can be properly
// cleaned up when the function is destroyed.
for (auto *block : invalidBlocks)
function->push_back(block);
}
/// After the function is finished parsing, this function checks to see if
/// there are any remaining issues.
ParseResult finalizeFunction(SMLoc loc);
@ -2187,6 +2170,25 @@ public:
// Block references.
ParseResult
parseOperationBlockList(SmallVectorImpl<Block *> &results,
ArrayRef<std::pair<SSAUseInfo, Type>> entryArguments);
ParseResult parseBlockListBody(SmallVectorImpl<Block *> &results);
ParseResult parseBlock(Block *&block);
ParseResult parseBlockBody(Block *block);
ParseResult
parseOptionalBlockArgList(SmallVectorImpl<BlockArgument *> &results,
Block *owner);
/// Cleans up the memory for allocated blocks when a parser error occurs.
void cleanupInvalidBlocks(ArrayRef<Block *> invalidBlocks) {
// Add the referenced blocks to the function so that they can be properly
// cleaned up when the function is destroyed.
for (auto *block : invalidBlocks)
function->push_back(block);
}
/// Get the block with the specified name, creating it if it doesn't
/// already exist. The location specified is the point of use, which allows
/// us to diagnose references to blocks that are not defined precisely.
@ -2201,13 +2203,6 @@ public:
OperationInst *parseGenericOperation();
OperationInst *parseCustomOperation();
ParseResult parseForInst();
ParseResult parseIntConstant(int64_t &val);
ParseResult parseDimAndSymbolList(SmallVectorImpl<Value *> &operands,
unsigned numDims, unsigned numOperands,
const char *affineStructName);
ParseResult parseBound(SmallVectorImpl<Value *> &operands, AffineMap &map,
bool isLower);
ParseResult parseInstructions(Block *block);
private:
@ -2287,25 +2282,43 @@ ParseResult FunctionParser::parseFunctionBody(bool hadNamedArguments) {
///
/// block-list ::= '{' block-list-body
///
ParseResult
FunctionParser::parseOperationBlockList(SmallVectorImpl<Block *> &results) {
ParseResult FunctionParser::parseOperationBlockList(
SmallVectorImpl<Block *> &results,
ArrayRef<std::pair<FunctionParser::SSAUseInfo, Type>> entryArguments) {
// Parse the '{'.
if (parseToken(Token::l_brace, "expected '{' to begin block list"))
return ParseFailure;
// Check for an empty block list.
if (consumeIf(Token::r_brace))
if (entryArguments.empty() && consumeIf(Token::r_brace))
return ParseSuccess;
Block *currentBlock = builder.getInsertionBlock();
// Parse the first block directly to allow for it to be unnamed.
Block *block = new Block();
// Add arguments to the entry block.
for (auto &placeholderArgPair : entryArguments)
if (addDefinition(placeholderArgPair.first,
block->addArgument(placeholderArgPair.second))) {
delete block;
return ParseFailure;
}
if (parseBlock(block)) {
cleanupInvalidBlocks(block);
delete block;
return ParseFailure;
}
results.push_back(block);
// Verify that no other arguments were parsed.
if (!entryArguments.empty() &&
block->getNumArguments() > entryArguments.size()) {
delete block;
return emitError("entry block arguments were already defined");
}
// Parse the rest of the block list.
results.push_back(block);
if (parseBlockListBody(results))
return ParseFailure;
@ -2385,10 +2398,6 @@ ParseResult FunctionParser::parseBlockBody(Block *block) {
if (parseOperation())
return ParseFailure;
break;
case Token::kw_for:
if (parseForInst())
return ParseFailure;
break;
}
}
@ -2859,7 +2868,7 @@ OperationInst *FunctionParser::parseGenericOperation() {
std::vector<SmallVector<Block *, 2>> blocks;
while (getToken().is(Token::l_brace)) {
SmallVector<Block *, 2> newBlocks;
if (parseOperationBlockList(newBlocks)) {
if (parseOperationBlockList(newBlocks, /*entryArguments=*/llvm::None)) {
for (auto &blockList : blocks)
cleanupInvalidBlocks(blockList);
return nullptr;
@ -2884,6 +2893,27 @@ public:
CustomOpAsmParser(SMLoc nameLoc, StringRef opName, FunctionParser &parser)
: nameLoc(nameLoc), opName(opName), parser(parser) {}
bool parseOperation(const AbstractOperation *opDefinition,
OperationState *opState) {
if (opDefinition->parseAssembly(this, opState))
return true;
// Check that enough block lists were reserved for those that were parsed.
if (parsedBlockLists.size() > opState->numBlockLists) {
return emitError(
nameLoc,
"parsed more block lists than those reserved in the operation state");
}
// Check there were no dangling entry block arguments.
if (!parsedBlockListEntryArguments.empty()) {
return emitError(
nameLoc,
"no block list was attached to parsed entry block arguments");
}
return false;
}
//===--------------------------------------------------------------------===//
// High level parsing methods.
//===--------------------------------------------------------------------===//
@ -2895,6 +2925,9 @@ public:
bool parseComma() override {
return parser.parseToken(Token::comma, "expected ','");
}
bool parseEqual() override {
return parser.parseToken(Token::equal, "expected '='");
}
bool parseType(Type &result) override {
return !(result = parser.parseType());
@ -3083,13 +3116,35 @@ public:
/// Parses a list of blocks.
bool parseBlockList() override {
// Parse the block list.
SmallVector<Block *, 2> results;
if (parser.parseOperationBlockList(results))
if (parser.parseOperationBlockList(results, parsedBlockListEntryArguments))
return true;
parsedBlockListEntryArguments.clear();
parsedBlockLists.emplace_back(results);
return false;
}
/// Parses an argument for the entry block of the next block list to be
/// parsed.
bool parseBlockListEntryBlockArgument(Type argType) override {
SmallVector<Value *, 1> argValues;
OperandType operand;
if (parseOperand(operand))
return true;
// Create a place holder for this argument.
FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number,
operand.location};
if (auto *value = parser.resolveSSAUse(operandInfo, argType)) {
parsedBlockListEntryArguments.emplace_back(operandInfo, argType);
return false;
}
return true;
}
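  // A hedged sketch (not part of this patch) of how an op's custom
  // parseAssembly is expected to combine the two hooks above: each entry
  // block argument is announced via parseBlockListEntryBlockArgument, and the
  // following parseBlockList call attaches the pending arguments to the entry
  // block it parses. The helper name and the way 'indexType' is obtained are
  // assumptions for exposition only.
  //
  //   static bool parseLoopLikeOp(OpAsmParser *parser, Type indexType) {
  //     // Parses '%iv' and records it for the next block list.
  //     if (parser->parseBlockListEntryBlockArgument(indexType))
  //       return true;
  //     if (parser->parseEqual())
  //       return true;
  //     // ... bound parsing elided ...
  //     // Parses '{ ... }', binding '%iv' as the entry block's argument.
  //     return parser->parseBlockList();
  //   }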
//===--------------------------------------------------------------------===//
// Methods for interacting with the parser
//===--------------------------------------------------------------------===//
@ -3130,6 +3185,8 @@ public:
private:
std::vector<SmallVector<Block *, 2>> parsedBlockLists;
SmallVector<std::pair<FunctionParser::SSAUseInfo, Type>, 2>
parsedBlockListEntryArguments;
SMLoc nameLoc;
StringRef opName;
FunctionParser &parser;
@ -3161,26 +3218,18 @@ OperationInst *FunctionParser::parseCustomOperation() {
// Have the op implementation take a crack and parsing this.
OperationState opState(builder.getContext(), srcLocation, opName);
if (opDefinition->parseAssembly(&opAsmParser, &opState))
if (opAsmParser.parseOperation(opDefinition, &opState))
return nullptr;
// If it emitted an error, we failed.
if (opAsmParser.didEmitError())
return nullptr;
// Check that enough block lists were reserved for those that were parsed.
auto parsedBlockLists = opAsmParser.getParsedBlockLists();
if (parsedBlockLists.size() > opState.numBlockLists) {
opAsmParser.emitError(
opLoc,
"parsed more block lists than those reserved in the operation state");
return nullptr;
}
// Otherwise, we succeeded. Use the state it parsed as our op information.
auto *opInst = builder.createOperation(opState);
// Resolve any parsed block lists.
auto parsedBlockLists = opAsmParser.getParsedBlockLists();
for (unsigned i = 0, e = parsedBlockLists.size(); i != e; ++i) {
auto &opBlockList = opInst->getBlockList(i).getBlocks();
opBlockList.insert(opBlockList.end(), parsedBlockLists[i].begin(),
@ -3189,213 +3238,6 @@ OperationInst *FunctionParser::parseCustomOperation() {
return opInst;
}
/// For instruction.
///
/// ml-for-inst ::= `for` ssa-id `=` lower-bound `to` upper-bound
/// (`step` integer-literal)? trailing-location? `{` inst* `}`
///
ParseResult FunctionParser::parseForInst() {
consumeToken(Token::kw_for);
// Parse induction variable.
if (getToken().isNot(Token::percent_identifier))
return emitError("expected SSA identifier for the loop variable");
auto loc = getToken().getLoc();
StringRef inductionVariableName = getTokenSpelling();
consumeToken(Token::percent_identifier);
if (parseToken(Token::equal, "expected '='"))
return ParseFailure;
// Parse lower bound.
SmallVector<Value *, 4> lbOperands;
AffineMap lbMap;
if (parseBound(lbOperands, lbMap, /*isLower*/ true))
return ParseFailure;
if (parseToken(Token::kw_to, "expected 'to' between bounds"))
return ParseFailure;
// Parse upper bound.
SmallVector<Value *, 4> ubOperands;
AffineMap ubMap;
if (parseBound(ubOperands, ubMap, /*isLower*/ false))
return ParseFailure;
// Parse step.
int64_t step = 1;
if (consumeIf(Token::kw_step) && parseIntConstant(step))
return ParseFailure;
// The loop step is a positive integer constant. Since index is stored as an
// int64_t type, we restrict step to be in the set of positive integers that
// int64_t can represent.
if (step < 1) {
return emitError("step has to be a positive integer");
}
// Create for instruction.
ForInst *forInst =
builder.createFor(getEncodedSourceLocation(loc), lbOperands, lbMap,
ubOperands, ubMap, step);
// Create SSA value definition for the induction variable.
if (addDefinition({inductionVariableName, 0, loc},
forInst->getInductionVar()))
return ParseFailure;
// Try to parse the optional trailing location.
if (parseOptionalTrailingLocation(forInst))
return ParseFailure;
// If parsing of the for instruction body fails,
// MLIR contains for instruction with those nested instructions that have been
// successfully parsed.
auto *forBody = forInst->getBody();
if (parseToken(Token::l_brace, "expected '{' before instruction list") ||
parseBlock(forBody) ||
parseToken(Token::r_brace, "expected '}' after instruction list"))
return ParseFailure;
// Reset insertion point to the current block.
builder.setInsertionPointToEnd(forInst->getBlock());
return ParseSuccess;
}
/// Parse integer constant as affine constant expression.
ParseResult FunctionParser::parseIntConstant(int64_t &val) {
bool negate = consumeIf(Token::minus);
if (getToken().isNot(Token::integer))
return emitError("expected integer");
auto uval = getToken().getUInt64IntegerValue();
if (!uval.hasValue() || (int64_t)uval.getValue() < 0) {
return emitError("bound or step is too large for index");
}
val = (int64_t)uval.getValue();
if (negate)
val = -val;
consumeToken();
return ParseSuccess;
}
/// Dimensions and symbol use list.
///
/// dim-use-list ::= `(` ssa-use-list? `)`
/// symbol-use-list ::= `[` ssa-use-list? `]`
/// dim-and-symbol-use-list ::= dim-use-list symbol-use-list?
///
ParseResult
FunctionParser::parseDimAndSymbolList(SmallVectorImpl<Value *> &operands,
unsigned numDims, unsigned numOperands,
const char *affineStructName) {
if (parseToken(Token::l_paren, "expected '('"))
return ParseFailure;
SmallVector<SSAUseInfo, 4> opInfo;
parseOptionalSSAUseList(opInfo);
if (parseToken(Token::r_paren, "expected ')'"))
return ParseFailure;
if (numDims != opInfo.size())
return emitError("dim operand count and " + Twine(affineStructName) +
" dim count must match");
if (consumeIf(Token::l_square)) {
parseOptionalSSAUseList(opInfo);
if (parseToken(Token::r_square, "expected ']'"))
return ParseFailure;
}
if (numOperands != opInfo.size())
return emitError("symbol operand count and " + Twine(affineStructName) +
" symbol count must match");
// Resolve SSA uses.
Type indexType = builder.getIndexType();
for (unsigned i = 0, e = opInfo.size(); i != e; ++i) {
Value *sval = resolveSSAUse(opInfo[i], indexType);
if (!sval)
return ParseFailure;
if (i < numDims && !sval->isValidDim())
return emitError(opInfo[i].loc, "value '" + opInfo[i].name.str() +
"' cannot be used as a dimension id");
if (i >= numDims && !sval->isValidSymbol())
return emitError(opInfo[i].loc, "value '" + opInfo[i].name.str() +
"' cannot be used as a symbol");
operands.push_back(sval);
}
return ParseSuccess;
}
/// Loop bound.
///
///  lower-bound ::= `max`? affine-map dim-and-symbol-use-list | shorthand-bound
///  upper-bound ::= `min`? affine-map dim-and-symbol-use-list | shorthand-bound
///  shorthand-bound ::= ssa-id | `-`? integer-literal
///
ParseResult FunctionParser::parseBound(SmallVectorImpl<Value *> &operands,
AffineMap &map, bool isLower) {
// 'min' / 'max' prefixes are syntactic sugar. Ignore them.
if (isLower)
consumeIf(Token::kw_max);
else
consumeIf(Token::kw_min);
// Parse full form - affine map followed by dim and symbol list.
if (getToken().isAny(Token::hash_identifier, Token::l_paren)) {
map = parseAffineMapReference();
if (!map)
return ParseFailure;
if (parseDimAndSymbolList(operands, map.getNumDims(), map.getNumInputs(),
"affine map"))
return ParseFailure;
return ParseSuccess;
}
// Parse custom assembly form.
if (getToken().isAny(Token::minus, Token::integer)) {
int64_t val;
if (!parseIntConstant(val)) {
map = builder.getConstantAffineMap(val);
return ParseSuccess;
}
return ParseFailure;
}
// Parse ssa-id as identity map.
SSAUseInfo opInfo;
if (parseSSAUse(opInfo))
return ParseFailure;
// TODO: improve error message when SSA value is not an affine integer.
// Currently it is 'use of value ... expects different type than prior uses'
if (auto *value = resolveSSAUse(opInfo, builder.getIndexType()))
operands.push_back(value);
else
return ParseFailure;
// Create an identity map using dim id for an induction variable and
// symbol otherwise. This representation is optimized for storage.
// Analysis passes may expand it into a multi-dimensional map if desired.
if (isForInductionVar(operands[0]))
map = builder.getDimIdentityMap();
else
map = builder.getSymbolIdentityMap();
return ParseSuccess;
}
/// Parse an affine constraint.
/// affine-constraint ::= affine-expr `>=` `0`
/// | affine-expr `==` `0`

View File

@ -183,11 +183,6 @@ void CSE::simplifyBlock(Block *bb) {
}
break;
}
case Instruction::Kind::For: {
ScopedMapTy::ScopeTy scope(knownValues);
simplifyBlock(cast<ForInst>(i).getBody());
break;
}
}
}
}

View File

@ -15,6 +15,7 @@
// limitations under the License.
// =============================================================================
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/InstVisitor.h"
@ -37,7 +38,6 @@ struct ConstantFold : public FunctionPass, InstWalker<ConstantFold> {
bool foldOperation(OperationInst *op,
SmallVectorImpl<Value *> &existingConstants);
void visitOperationInst(OperationInst *inst);
void visitForInst(ForInst *inst);
PassResult runOnFunction(Function *f) override;
static char passID;
@ -50,6 +50,12 @@ char ConstantFold::passID = 0;
/// constants are found, we keep track of them in the existingConstants list.
///
void ConstantFold::visitOperationInst(OperationInst *op) {
// If this operation is an AffineForOp, then fold the bounds.
if (auto forOp = op->dyn_cast<AffineForOp>()) {
constantFoldBounds(forOp);
return;
}
// If this operation is already a constant, just remember it for cleanup
// later, and don't try to fold it.
if (auto constant = op->dyn_cast<ConstantOp>()) {
@ -98,11 +104,6 @@ void ConstantFold::visitOperationInst(OperationInst *op) {
opInstsToErase.push_back(op);
}
// Override the walker's 'for' instruction visit for constant folding.
void ConstantFold::visitForInst(ForInst *forInst) {
constantFoldBounds(forInst);
}
// For now, we do a simple top-down pass over a function folding constants. We
// don't handle conditional control flow, block arguments, folding
// conditional branches, or anything else fancy.

View File

@ -21,6 +21,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
@ -71,9 +72,9 @@ struct DmaGeneration : public FunctionPass {
}
PassResult runOnFunction(Function *f) override;
void runOnForInst(ForInst *forInst);
void runOnAffineForOp(OpPointer<AffineForOp> forOp);
bool generateDma(const MemRefRegion &region, ForInst *forInst,
bool generateDma(const MemRefRegion &region, OpPointer<AffineForOp> forOp,
uint64_t *sizeInBytes);
// List of memory regions to DMA for. We need a map vector to have a
@ -174,7 +175,7 @@ static bool getFullMemRefAsRegion(OperationInst *opInst,
// Just get the first numSymbols IVs, which the memref region is parametric
// on.
SmallVector<ForInst *, 4> ivs;
SmallVector<OpPointer<AffineForOp>, 4> ivs;
getLoopIVs(*opInst, &ivs);
ivs.resize(numParamLoopIVs);
SmallVector<Value *, 4> symbols = extractForInductionVars(ivs);
@ -195,8 +196,10 @@ static bool getFullMemRefAsRegion(OperationInst *opInst,
// generates a DMA from the lower memory space to this one, and replaces all
// loads to load from that buffer. Returns false if DMAs could not be generated
// due to yet unimplemented cases.
bool DmaGeneration::generateDma(const MemRefRegion &region, ForInst *forInst,
bool DmaGeneration::generateDma(const MemRefRegion &region,
OpPointer<AffineForOp> forOp,
uint64_t *sizeInBytes) {
auto *forInst = forOp->getInstruction();
// DMAs for read regions are going to be inserted just before the for loop.
FuncBuilder prologue(forInst);
@ -386,39 +389,43 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForInst *forInst,
remapExprs.push_back(dimExpr - offsets[i]);
}
auto indexRemap = b->getAffineMap(outerIVs.size() + rank, 0, remapExprs, {});
// *Only* those uses within the body of 'forInst' are replaced.
// *Only* those uses within the body of 'forOp' are replaced.
replaceAllMemRefUsesWith(memref, fastMemRef,
/*extraIndices=*/{}, indexRemap,
/*extraOperands=*/outerIVs,
/*domInstFilter=*/&*forInst->getBody()->begin());
/*domInstFilter=*/&*forOp->getBody()->begin());
return true;
}
// TODO(bondhugula): make this run on a Block instead of a 'for' inst.
void DmaGeneration::runOnForInst(ForInst *forInst) {
void DmaGeneration::runOnAffineForOp(OpPointer<AffineForOp> forOp) {
// For now (for testing purposes), we'll run this on the outermost among 'for'
// inst's with unit stride, i.e., right at the top of the tile if tiling has
// been done. In the future, the DMA generation has to be done at a level
// where the generated data fits in a higher level of the memory hierarchy; so
// the pass has to be instantiated with additional information that we aren't
// provided with at the moment.
if (forInst->getStep() != 1) {
if (auto *innerFor = dyn_cast<ForInst>(&*forInst->getBody()->begin())) {
runOnForInst(innerFor);
if (forOp->getStep() != 1) {
auto *forBody = forOp->getBody();
if (forBody->empty())
return;
if (auto innerFor =
cast<OperationInst>(forBody->front()).dyn_cast<AffineForOp>()) {
runOnAffineForOp(innerFor);
}
return;
}
// DMAs will be generated for this depth, i.e., for all data accessed by this
// loop.
unsigned dmaDepth = getNestingDepth(*forInst);
unsigned dmaDepth = getNestingDepth(*forOp->getInstruction());
readRegions.clear();
writeRegions.clear();
fastBufferMap.clear();
// Walk this 'for' instruction to gather all memory regions.
forInst->walkOps([&](OperationInst *opInst) {
forOp->walkOps([&](OperationInst *opInst) {
// Gather regions to promote to buffers in faster memory space.
// TODO(bondhugula): handle store op's; only load's handled for now.
if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
@ -443,7 +450,7 @@ void DmaGeneration::runOnForInst(ForInst *forInst) {
LLVM_DEBUG(llvm::dbgs() << "over-approximating to the entire memref\n");
if (!getFullMemRefAsRegion(opInst, dmaDepth, region.get())) {
LLVM_DEBUG(
forInst->emitError("Non-constant memref sizes not yet supported"));
forOp->emitError("Non-constant memref sizes not yet supported"));
return;
}
}
@ -472,10 +479,10 @@ void DmaGeneration::runOnForInst(ForInst *forInst) {
// Perform a union with the existing region.
if (!(*it).second->unionBoundingBox(*region)) {
LLVM_DEBUG(llvm::dbgs()
<< "Memory region bounding box failed; "
<< "Memory region bounding box failed"
"over-approximating to the entire memref\n");
if (!getFullMemRefAsRegion(opInst, dmaDepth, region.get())) {
LLVM_DEBUG(forInst->emitError(
LLVM_DEBUG(forOp->emitError(
"Non-constant memref sizes not yet supported"));
}
}
@ -501,7 +508,7 @@ void DmaGeneration::runOnForInst(ForInst *forInst) {
&regions) {
for (const auto &regionEntry : regions) {
uint64_t sizeInBytes;
bool iRet = generateDma(*regionEntry.second, forInst, &sizeInBytes);
bool iRet = generateDma(*regionEntry.second, forOp, &sizeInBytes);
if (iRet)
totalSizeInBytes += sizeInBytes;
ret = ret & iRet;
@ -510,7 +517,7 @@ void DmaGeneration::runOnForInst(ForInst *forInst) {
processRegions(readRegions);
processRegions(writeRegions);
if (!ret) {
forInst->emitError("DMA generation failed for one or more memref's\n");
forOp->emitError("DMA generation failed for one or more memref's\n");
return;
}
LLVM_DEBUG(llvm::dbgs() << Twine(llvm::divideCeil(totalSizeInBytes, 1024))
@ -519,7 +526,7 @@ void DmaGeneration::runOnForInst(ForInst *forInst) {
if (clFastMemoryCapacity && totalSizeInBytes > clFastMemoryCapacity) {
// TODO(bondhugula): selecting the DMA depth so that the result DMA buffers
// fit in fast memory is a TODO - not complex.
forInst->emitError(
forOp->emitError(
"Total size of all DMA buffers' exceeds memory capacity\n");
}
}
@ -531,8 +538,8 @@ PassResult DmaGeneration::runOnFunction(Function *f) {
for (auto &block : *f) {
for (auto &inst : block) {
if (auto *forInst = dyn_cast<ForInst>(&inst)) {
runOnForInst(forInst);
if (auto forOp = cast<OperationInst>(inst).dyn_cast<AffineForOp>()) {
runOnAffineForOp(forOp);
}
}
}
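// A hedged sketch (not part of this patch) of the loop-gathering pattern this
// pass now relies on: enclosing loops come back as OpPointer<AffineForOp>
// handles, and their induction variables are extracted separately. The helper
// name 'collectOuterIVs' is illustrative only.
//
//   static void collectOuterIVs(OperationInst *opInst, unsigned numOuterLoops,
//                               SmallVectorImpl<Value *> &ivValues) {
//     SmallVector<OpPointer<AffineForOp>, 4> loops;
//     getLoopIVs(*opInst, &loops);
//     // Keep only the outermost loops, as done above for the loops the
//     // memref region is parametric on.
//     if (loops.size() > numOuterLoops)
//       loops.resize(numOuterLoops);
//     auto values = extractForInductionVars(loops);
//     ivValues.append(values.begin(), values.end());
//   }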

View File

@ -97,15 +97,15 @@ namespace {
// operations, and whether or not an IfInst was encountered in the loop nest.
class LoopNestStateCollector : public InstWalker<LoopNestStateCollector> {
public:
SmallVector<ForInst *, 4> forInsts;
SmallVector<OpPointer<AffineForOp>, 4> forOps;
SmallVector<OperationInst *, 4> loadOpInsts;
SmallVector<OperationInst *, 4> storeOpInsts;
bool hasNonForRegion = false;
void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); }
void visitOperationInst(OperationInst *opInst) {
if (opInst->getNumBlockLists() != 0)
if (opInst->isa<AffineForOp>())
forOps.push_back(opInst->cast<AffineForOp>());
else if (opInst->getNumBlockLists() != 0)
hasNonForRegion = true;
else if (opInst->isa<LoadOp>())
loadOpInsts.push_back(opInst);
@ -491,14 +491,14 @@ bool MemRefDependenceGraph::init(Function *f) {
if (f->getBlocks().size() != 1)
return false;
DenseMap<ForInst *, unsigned> forToNodeMap;
DenseMap<Instruction *, unsigned> forToNodeMap;
for (auto &inst : f->front()) {
if (auto *forInst = dyn_cast<ForInst>(&inst)) {
// Create graph node 'id' to represent top-level 'forInst' and record
if (auto forOp = cast<OperationInst>(&inst)->dyn_cast<AffineForOp>()) {
// Create graph node 'id' to represent top-level 'forOp' and record
// all loads and store accesses it contains.
LoopNestStateCollector collector;
collector.walkForInst(forInst);
// Return false if IfInsts are found (not currently supported).
collector.walk(&inst);
// Return false if a non 'for' region was found (not currently supported).
if (collector.hasNonForRegion)
return false;
Node node(nextNodeId++, &inst);
@ -512,10 +512,9 @@ bool MemRefDependenceGraph::init(Function *f) {
auto *memref = opInst->cast<StoreOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
}
forToNodeMap[forInst] = node.id;
forToNodeMap[&inst] = node.id;
nodes.insert({node.id, node});
}
if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
} else if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
// Create graph node for top-level load op.
Node node(nextNodeId++, &inst);
@ -552,12 +551,12 @@ bool MemRefDependenceGraph::init(Function *f) {
for (auto *value : opInst->getResults()) {
for (auto &use : value->getUses()) {
auto *userOpInst = cast<OperationInst>(use.getOwner());
SmallVector<ForInst *, 4> loops;
SmallVector<OpPointer<AffineForOp>, 4> loops;
getLoopIVs(*userOpInst, &loops);
if (loops.empty())
continue;
assert(forToNodeMap.count(loops[0]) > 0);
unsigned userLoopNestId = forToNodeMap[loops[0]];
assert(forToNodeMap.count(loops[0]->getInstruction()) > 0);
unsigned userLoopNestId = forToNodeMap[loops[0]->getInstruction()];
addEdge(node.id, userLoopNestId, value);
}
}
@ -587,12 +586,12 @@ namespace {
// LoopNestStats aggregates various per-loop statistics (eg. loop trip count
// and operation count) for a loop nest up until the innermost loop body.
struct LoopNestStats {
// Map from ForInst to immediate child ForInsts in its loop body.
DenseMap<ForInst *, SmallVector<ForInst *, 2>> loopMap;
// Map from ForInst to count of operations in its loop body.
DenseMap<ForInst *, uint64_t> opCountMap;
// Map from ForInst to its constant trip count.
DenseMap<ForInst *, uint64_t> tripCountMap;
// Map from AffineForOp to immediate child AffineForOps in its loop body.
DenseMap<Instruction *, SmallVector<OpPointer<AffineForOp>, 2>> loopMap;
// Map from AffineForOp to count of operations in its loop body.
DenseMap<Instruction *, uint64_t> opCountMap;
// Map from AffineForOp to its constant trip count.
DenseMap<Instruction *, uint64_t> tripCountMap;
};
// LoopNestStatsCollector walks a single loop nest and gathers per-loop
@ -604,23 +603,31 @@ public:
LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {}
void visitForInst(ForInst *forInst) {
auto *parentInst = forInst->getParentInst();
void visitOperationInst(OperationInst *opInst) {
auto forOp = opInst->dyn_cast<AffineForOp>();
if (!forOp)
return;
auto *forInst = forOp->getInstruction();
auto *parentInst = forOp->getInstruction()->getParentInst();
if (parentInst != nullptr) {
assert(isa<ForInst>(parentInst) && "Expected parent ForInst");
// Add mapping to 'forInst' from its parent ForInst.
stats->loopMap[cast<ForInst>(parentInst)].push_back(forInst);
assert(cast<OperationInst>(parentInst)->isa<AffineForOp>() &&
"Expected parent AffineForOp");
// Add mapping to 'forOp' from its parent AffineForOp.
stats->loopMap[parentInst].push_back(forOp);
}
// Record the number of op instructions in the body of 'forInst'.
// Record the number of op instructions in the body of 'forOp'.
unsigned count = 0;
stats->opCountMap[forInst] = 0;
for (auto &inst : *forInst->getBody()) {
if (isa<OperationInst>(&inst))
for (auto &inst : *forOp->getBody()) {
if (!(cast<OperationInst>(inst).isa<AffineForOp>() ||
cast<OperationInst>(inst).isa<AffineIfOp>()))
++count;
}
stats->opCountMap[forInst] = count;
// Record trip count for 'forInst'. Set flag if trip count is not constant.
Optional<uint64_t> maybeConstTripCount = getConstantTripCount(*forInst);
// Record trip count for 'forOp'. Set flag if trip count is not constant.
Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
if (!maybeConstTripCount.hasValue()) {
hasLoopWithNonConstTripCount = true;
return;
@ -629,7 +636,7 @@ public:
}
};
// Computes the total cost of the loop nest rooted at 'forInst'.
// Computes the total cost of the loop nest rooted at 'forOp'.
// Currently, the total cost is computed by counting the total operation
// instance count (i.e. total number of operations in the loop body: loop
// operation count * loop trip count) for the entire loop nest.
@ -637,7 +644,7 @@ public:
// specified in the map when computing the total op instance count.
// NOTE: this is used to compute the cost of computation slices, which are
// sliced along the iteration dimension, and thus reduce the trip count.
// If 'computeCostMap' is non-null, the total op count for forInsts specified
// If 'computeCostMap' is non-null, the total op count for forOps specified
// in the map is increased (not overridden) by adding the op count from the
// map to the existing op count for the for loop. This is done before
// multiplying by the loop's trip count, and is used to model the cost of
@ -645,15 +652,15 @@ public:
// NOTE: this is used to compute the cost of fusing a slice of some loop nest
// within another loop.
static int64_t getComputeCost(
ForInst *forInst, LoopNestStats *stats,
llvm::SmallDenseMap<ForInst *, uint64_t, 8> *tripCountOverrideMap,
DenseMap<ForInst *, int64_t> *computeCostMap) {
// 'opCount' is the total number operations in one iteration of 'forInst' body
Instruction *forInst, LoopNestStats *stats,
llvm::SmallDenseMap<Instruction *, uint64_t, 8> *tripCountOverrideMap,
DenseMap<Instruction *, int64_t> *computeCostMap) {
// 'opCount' is the total number of operations in one iteration of 'forOp' body
int64_t opCount = stats->opCountMap[forInst];
if (stats->loopMap.count(forInst) > 0) {
for (auto *childForInst : stats->loopMap[forInst]) {
opCount += getComputeCost(childForInst, stats, tripCountOverrideMap,
computeCostMap);
for (auto childForOp : stats->loopMap[forInst]) {
opCount += getComputeCost(childForOp->getInstruction(), stats,
tripCountOverrideMap, computeCostMap);
}
}
// Add in additional op instances from slice (if specified in map).
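// The cost model described above is simply "ops per iteration times trip
// count, applied bottom-up over the nest". A self-contained toy illustration
// of that recursion (plain C++, not the MLIR types used in this pass):
//
//   #include <cstdint>
//   #include <vector>
//
//   struct ToyLoop {
//     int64_t opCount;               // non-loop ops directly in the body
//     uint64_t tripCount;            // constant trip count
//     std::vector<ToyLoop> children; // immediately nested loops
//   };
//
//   int64_t toyComputeCost(const ToyLoop &loop) {
//     int64_t count = loop.opCount;
//     for (const ToyLoop &child : loop.children)
//       count += toyComputeCost(child);
//     return count * static_cast<int64_t>(loop.tripCount);
//   }
//
//   // for %i = 0 to 100 { for %j = 0 to 50 { 3 ops } } evaluates to
//   // 3 * 50 * 100 = 15000 operation instances.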
@ -694,18 +701,18 @@ static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
return cExpr.getValue();
}
// Builds a map 'tripCountMap' from ForInst to constant trip count for loop
// Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop
// nest surrounding 'srcAccess' utilizing slice loop bounds in 'sliceState'.
// Returns true on success, false otherwise (if a non-constant trip count
// was encountered).
// TODO(andydavis) Make this work with non-unit step loops.
static bool buildSliceTripCountMap(
OperationInst *srcOpInst, ComputationSliceState *sliceState,
llvm::SmallDenseMap<ForInst *, uint64_t, 8> *tripCountMap) {
SmallVector<ForInst *, 4> srcLoopIVs;
llvm::SmallDenseMap<Instruction *, uint64_t, 8> *tripCountMap) {
SmallVector<OpPointer<AffineForOp>, 4> srcLoopIVs;
getLoopIVs(*srcOpInst, &srcLoopIVs);
unsigned numSrcLoopIVs = srcLoopIVs.size();
// Populate map from ForInst -> trip count
// Populate map from AffineForOp -> trip count
for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
AffineMap lbMap = sliceState->lbs[i];
AffineMap ubMap = sliceState->ubs[i];
@ -713,7 +720,7 @@ static bool buildSliceTripCountMap(
// The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
if (srcLoopIVs[i]->hasConstantLowerBound() &&
srcLoopIVs[i]->hasConstantUpperBound()) {
(*tripCountMap)[srcLoopIVs[i]] =
(*tripCountMap)[srcLoopIVs[i]->getInstruction()] =
srcLoopIVs[i]->getConstantUpperBound() -
srcLoopIVs[i]->getConstantLowerBound();
continue;
@ -723,7 +730,7 @@ static bool buildSliceTripCountMap(
Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
if (!tripCount.hasValue())
return false;
(*tripCountMap)[srcLoopIVs[i]] = tripCount.getValue();
(*tripCountMap)[srcLoopIVs[i]->getInstruction()] = tripCount.getValue();
}
return true;
}
@ -750,7 +757,7 @@ static unsigned getInnermostCommonLoopDepth(ArrayRef<OperationInst *> ops) {
unsigned numOps = ops.size();
assert(numOps > 0);
std::vector<SmallVector<ForInst *, 4>> loops(numOps);
std::vector<SmallVector<OpPointer<AffineForOp>, 4>> loops(numOps);
unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
for (unsigned i = 0; i < numOps; ++i) {
getLoopIVs(*ops[i], &loops[i]);
@ -762,9 +769,8 @@ static unsigned getInnermostCommonLoopDepth(ArrayRef<OperationInst *> ops) {
for (unsigned d = 0; d < loopDepthLimit; ++d) {
unsigned i;
for (i = 1; i < numOps; ++i) {
if (loops[i - 1][d] != loops[i][d]) {
if (loops[i - 1][d] != loops[i][d])
break;
}
}
if (i != numOps)
break;
@ -871,14 +877,16 @@ static bool getSliceUnion(const ComputationSliceState &sliceStateA,
}
// Creates and returns a private (single-user) memref for fused loop rooted
// at 'forInst', with (potentially reduced) memref size based on the
// at 'forOp', with (potentially reduced) memref size based on the
// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
// TODO(bondhugula): consider refactoring the common code from generateDma and
// this one.
static Value *createPrivateMemRef(ForInst *forInst,
static Value *createPrivateMemRef(OpPointer<AffineForOp> forOp,
OperationInst *srcStoreOpInst,
unsigned dstLoopDepth) {
// Create builder to insert alloc op just before 'forInst'.
auto *forInst = forOp->getInstruction();
// Create builder to insert alloc op just before 'forOp'.
FuncBuilder b(forInst);
// Builder to create constants at the top level.
FuncBuilder top(forInst->getFunction());
@ -934,16 +942,16 @@ static Value *createPrivateMemRef(ForInst *forInst,
for (auto dimSize : oldMemRefType.getShape()) {
if (dimSize == -1)
allocOperands.push_back(
top.create<DimOp>(forInst->getLoc(), oldMemRef, dynamicDimCount++));
top.create<DimOp>(forOp->getLoc(), oldMemRef, dynamicDimCount++));
}
// Create new private memref for fused loop 'forInst'.
// Create new private memref for fused loop 'forOp'.
// TODO(andydavis) Create/move alloc ops for private memrefs closer to their
// consumer loop nests to reduce their live range. Currently they are added
// at the beginning of the function, because loop nests can be reordered
// during the fusion pass.
Value *newMemRef =
top.create<AllocOp>(forInst->getLoc(), newMemRefType, allocOperands);
top.create<AllocOp>(forOp->getLoc(), newMemRefType, allocOperands);
// Build an AffineMap to remap access functions based on lower bound offsets.
SmallVector<AffineExpr, 4> remapExprs;
@ -967,7 +975,7 @@ static Value *createPrivateMemRef(ForInst *forInst,
bool ret =
replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
/*extraOperands=*/outerIVs,
/*domInstFilter=*/&*forInst->getBody()->begin());
/*domInstFilter=*/&*forOp->getBody()->begin());
assert(ret && "replaceAllMemrefUsesWith should always succeed here");
(void)ret;
return newMemRef;
@ -975,7 +983,7 @@ static Value *createPrivateMemRef(ForInst *forInst,
// Does the slice have a single iteration?
static uint64_t getSliceIterationCount(
const llvm::SmallDenseMap<ForInst *, uint64_t, 8> &sliceTripCountMap) {
const llvm::SmallDenseMap<Instruction *, uint64_t, 8> &sliceTripCountMap) {
uint64_t iterCount = 1;
for (const auto &count : sliceTripCountMap) {
iterCount *= count.second;
@ -1030,25 +1038,25 @@ static bool isFusionProfitable(OperationInst *srcOpInst,
});
// Compute cost of sliced and unsliced src loop nest.
SmallVector<ForInst *, 4> srcLoopIVs;
SmallVector<OpPointer<AffineForOp>, 4> srcLoopIVs;
getLoopIVs(*srcOpInst, &srcLoopIVs);
unsigned numSrcLoopIVs = srcLoopIVs.size();
// Walk src loop nest and collect stats.
LoopNestStats srcLoopNestStats;
LoopNestStatsCollector srcStatsCollector(&srcLoopNestStats);
srcStatsCollector.walk(srcLoopIVs[0]);
srcStatsCollector.walk(srcLoopIVs[0]->getInstruction());
// Currently only constant trip count loop nests are supported.
if (srcStatsCollector.hasLoopWithNonConstTripCount)
return false;
// Compute cost of dst loop nest.
SmallVector<ForInst *, 4> dstLoopIVs;
SmallVector<OpPointer<AffineForOp>, 4> dstLoopIVs;
getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs);
LoopNestStats dstLoopNestStats;
LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats);
dstStatsCollector.walk(dstLoopIVs[0]);
dstStatsCollector.walk(dstLoopIVs[0]->getInstruction());
// Currently only constant trip count loop nests are supported.
if (dstStatsCollector.hasLoopWithNonConstTripCount)
return false;
@ -1075,17 +1083,19 @@ static bool isFusionProfitable(OperationInst *srcOpInst,
Optional<unsigned> bestDstLoopDepth = None;
// Compute op instance count for the src loop nest without iteration slicing.
uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], &srcLoopNestStats,
/*tripCountOverrideMap=*/nullptr,
/*computeCostMap=*/nullptr);
uint64_t srcLoopNestCost =
getComputeCost(srcLoopIVs[0]->getInstruction(), &srcLoopNestStats,
/*tripCountOverrideMap=*/nullptr,
/*computeCostMap=*/nullptr);
// Compute op instance count for the src loop nest.
uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], &dstLoopNestStats,
/*tripCountOverrideMap=*/nullptr,
/*computeCostMap=*/nullptr);
uint64_t dstLoopNestCost =
getComputeCost(dstLoopIVs[0]->getInstruction(), &dstLoopNestStats,
/*tripCountOverrideMap=*/nullptr,
/*computeCostMap=*/nullptr);
llvm::SmallDenseMap<ForInst *, uint64_t, 8> sliceTripCountMap;
DenseMap<ForInst *, int64_t> computeCostMap;
llvm::SmallDenseMap<Instruction *, uint64_t, 8> sliceTripCountMap;
DenseMap<Instruction *, int64_t> computeCostMap;
for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
MemRefAccess srcAccess(srcOpInst);
// Handle the common case of one dst load without a copy.
@ -1121,24 +1131,25 @@ static bool isFusionProfitable(OperationInst *srcOpInst,
// The store and loads to this memref will disappear.
if (storeLoadFwdGuaranteed) {
// A single store disappears: -1 for that.
computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]] = -1;
computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]->getInstruction()] = -1;
for (auto *loadOp : dstLoadOpInsts) {
if (auto *loadLoop = dyn_cast_or_null<ForInst>(loadOp->getParentInst()))
computeCostMap[loadLoop] = -1;
auto *parentInst = loadOp->getParentInst();
if (parentInst && cast<OperationInst>(parentInst)->isa<AffineForOp>())
computeCostMap[parentInst] = -1;
}
}
// Compute op instance count for the src loop nest with iteration slicing.
int64_t sliceComputeCost =
getComputeCost(srcLoopIVs[0], &srcLoopNestStats,
getComputeCost(srcLoopIVs[0]->getInstruction(), &srcLoopNestStats,
/*tripCountOverrideMap=*/&sliceTripCountMap,
/*computeCostMap=*/&computeCostMap);
// Compute cost of fusion for this depth.
computeCostMap[dstLoopIVs[i - 1]] = sliceComputeCost;
computeCostMap[dstLoopIVs[i - 1]->getInstruction()] = sliceComputeCost;
int64_t fusedLoopNestComputeCost =
getComputeCost(dstLoopIVs[0], &dstLoopNestStats,
getComputeCost(dstLoopIVs[0]->getInstruction(), &dstLoopNestStats,
/*tripCountOverrideMap=*/nullptr, &computeCostMap);
double additionalComputeFraction =
@ -1211,8 +1222,8 @@ static bool isFusionProfitable(OperationInst *srcOpInst,
<< "\n fused loop nest compute cost: "
<< minFusedLoopNestComputeCost << "\n");
auto dstMemSize = getMemoryFootprintBytes(*dstLoopIVs[0]);
auto srcMemSize = getMemoryFootprintBytes(*srcLoopIVs[0]);
auto dstMemSize = getMemoryFootprintBytes(dstLoopIVs[0]);
auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);
Optional<double> storageReduction = None;
@ -1292,9 +1303,9 @@ static bool isFusionProfitable(OperationInst *srcOpInst,
//
// *) A worklist is initialized with node ids from the dependence graph.
// *) For each node id in the worklist:
// *) Pop a ForInst of the worklist. This 'dstForInst' will be a candidate
// destination ForInst into which fusion will be attempted.
// *) Add each LoadOp currently in 'dstForInst' into list 'dstLoadOps'.
// *) Pop an AffineForOp off the worklist. This 'dstAffineForOp' will be a
// candidate destination AffineForOp into which fusion will be attempted.
// *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'.
// *) For each LoadOp in 'dstLoadOps' do:
// *) Lookup dependent loop nests at earlier positions in the Function
// which have a single store op to the same memref.
@ -1342,7 +1353,7 @@ public:
// Get 'dstNode' into which to attempt fusion.
auto *dstNode = mdg->getNode(dstId);
// Skip if 'dstNode' is not a loop nest.
if (!isa<ForInst>(dstNode->inst))
if (!cast<OperationInst>(dstNode->inst)->isa<AffineForOp>())
continue;
SmallVector<OperationInst *, 4> loads = dstNode->loads;
@ -1375,7 +1386,7 @@ public:
// Get 'srcNode' from which to attempt fusion into 'dstNode'.
auto *srcNode = mdg->getNode(srcId);
// Skip if 'srcNode' is not a loop nest.
if (!isa<ForInst>(srcNode->inst))
if (!cast<OperationInst>(srcNode->inst)->isa<AffineForOp>())
continue;
// Skip if 'srcNode' has more than one store to any memref.
// TODO(andydavis) Support fusing multi-output src loop nests.
@ -1417,25 +1428,26 @@ public:
continue;
// Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
auto *sliceLoopNest = mlir::insertBackwardComputationSlice(
auto sliceLoopNest = mlir::insertBackwardComputationSlice(
srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
if (sliceLoopNest != nullptr) {
// Move 'dstForInst' before 'insertPointInst' if needed.
auto *dstForInst = cast<ForInst>(dstNode->inst);
if (insertPointInst != dstForInst) {
dstForInst->moveBefore(insertPointInst);
// Move 'dstAffineForOp' before 'insertPointInst' if needed.
auto dstAffineForOp =
cast<OperationInst>(dstNode->inst)->cast<AffineForOp>();
if (insertPointInst != dstAffineForOp->getInstruction()) {
dstAffineForOp->getInstruction()->moveBefore(insertPointInst);
}
// Update edges between 'srcNode' and 'dstNode'.
mdg->updateEdges(srcNode->id, dstNode->id, memref);
// Collect slice loop stats.
LoopNestStateCollector sliceCollector;
sliceCollector.walkForInst(sliceLoopNest);
sliceCollector.walk(sliceLoopNest->getInstruction());
// Promote single iteration slice loops to single IV value.
for (auto *forInst : sliceCollector.forInsts) {
promoteIfSingleIteration(forInst);
for (auto forOp : sliceCollector.forOps) {
promoteIfSingleIteration(forOp);
}
// Create private memref for 'memref' in 'dstForInst'.
// Create private memref for 'memref' in 'dstAffineForOp'.
SmallVector<OperationInst *, 4> storesForMemref;
for (auto *storeOpInst : sliceCollector.storeOpInsts) {
if (storeOpInst->cast<StoreOp>()->getMemRef() == memref)
@ -1443,7 +1455,7 @@ public:
}
assert(storesForMemref.size() == 1);
auto *newMemRef = createPrivateMemRef(
dstForInst, storesForMemref[0], bestDstLoopDepth);
dstAffineForOp, storesForMemref[0], bestDstLoopDepth);
visitedMemrefs.insert(newMemRef);
// Create new node in dependence graph for 'newMemRef' alloc op.
unsigned newMemRefNodeId =
@ -1453,7 +1465,7 @@ public:
// Collect dst loop stats after memref privatizaton transformation.
LoopNestStateCollector dstLoopCollector;
dstLoopCollector.walkForInst(dstForInst);
dstLoopCollector.walk(dstAffineForOp->getInstruction());
// Add new load ops to current Node load op list 'loads' to
// continue fusing based on new operands.
@ -1472,7 +1484,7 @@ public:
// function.
if (mdg->canRemoveNode(srcNode->id)) {
mdg->removeNode(srcNode->id);
cast<ForInst>(srcNode->inst)->erase();
srcNode->inst->erase();
}
}
}

View File

@ -19,6 +19,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/LoopAnalysis.h"
@ -60,16 +61,17 @@ char LoopTiling::passID = 0;
/// Function.
FunctionPass *mlir::createLoopTilingPass() { return new LoopTiling(); }
// Move the loop body of ForInst 'src' from 'src' into the specified location in
// destination's body.
static inline void moveLoopBody(ForInst *src, ForInst *dest,
// Move the loop body of AffineForOp 'src' from 'src' into the specified
// location in destination's body.
static inline void moveLoopBody(AffineForOp *src, AffineForOp *dest,
Block::iterator loc) {
dest->getBody()->getInstructions().splice(loc,
src->getBody()->getInstructions());
}
// Move the loop body of ForInst 'src' from 'src' to the start of dest's body.
static inline void moveLoopBody(ForInst *src, ForInst *dest) {
// Move the loop body of AffineForOp 'src' from 'src' to the start of dest's
// body.
static inline void moveLoopBody(AffineForOp *src, AffineForOp *dest) {
moveLoopBody(src, dest, dest->getBody()->begin());
}
@ -78,13 +80,14 @@ static inline void moveLoopBody(ForInst *src, ForInst *dest) {
/// depend on other dimensions. Bounds of each dimension can thus be treated
/// independently, and deriving the new bounds is much simpler and faster
/// than for the case of tiling arbitrary polyhedral shapes.
static void constructTiledIndexSetHyperRect(ArrayRef<ForInst *> origLoops,
ArrayRef<ForInst *> newLoops,
ArrayRef<unsigned> tileSizes) {
static void constructTiledIndexSetHyperRect(
MutableArrayRef<OpPointer<AffineForOp>> origLoops,
MutableArrayRef<OpPointer<AffineForOp>> newLoops,
ArrayRef<unsigned> tileSizes) {
assert(!origLoops.empty());
assert(origLoops.size() == tileSizes.size());
FuncBuilder b(origLoops[0]);
FuncBuilder b(origLoops[0]->getInstruction());
unsigned width = origLoops.size();
// Bounds for tile space loops.
@ -99,8 +102,8 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForInst *> origLoops,
}
// Bounds for intra-tile loops.
for (unsigned i = 0; i < width; i++) {
int64_t largestDiv = getLargestDivisorOfTripCount(*origLoops[i]);
auto mayBeConstantCount = getConstantTripCount(*origLoops[i]);
int64_t largestDiv = getLargestDivisorOfTripCount(origLoops[i]);
auto mayBeConstantCount = getConstantTripCount(origLoops[i]);
// The lower bound is just the tile-space loop.
AffineMap lbMap = b.getDimIdentityMap();
newLoops[width + i]->setLowerBound(
@ -144,38 +147,40 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForInst *> origLoops,
/// Tiles the specified band of perfectly nested loops creating tile-space loops
/// and intra-tile loops. A band is a contiguous set of loops.
// TODO(bondhugula): handle non hyper-rectangular spaces.
UtilResult mlir::tileCodeGen(ArrayRef<ForInst *> band,
UtilResult mlir::tileCodeGen(MutableArrayRef<OpPointer<AffineForOp>> band,
ArrayRef<unsigned> tileSizes) {
assert(!band.empty());
assert(band.size() == tileSizes.size());
// Check if the supplied for inst's are all successively nested.
for (unsigned i = 1, e = band.size(); i < e; i++) {
assert(band[i]->getParentInst() == band[i - 1]);
assert(band[i]->getInstruction()->getParentInst() ==
band[i - 1]->getInstruction());
}
auto origLoops = band;
ForInst *rootForInst = origLoops[0];
auto loc = rootForInst->getLoc();
OpPointer<AffineForOp> rootAffineForOp = origLoops[0];
auto loc = rootAffineForOp->getLoc();
// Note that width is at least one since band isn't empty.
unsigned width = band.size();
SmallVector<ForInst *, 12> newLoops(2 * width);
ForInst *innermostPointLoop;
SmallVector<OpPointer<AffineForOp>, 12> newLoops(2 * width);
OpPointer<AffineForOp> innermostPointLoop;
// The outermost among the loops as we add more..
auto *topLoop = rootForInst;
auto *topLoop = rootAffineForOp->getInstruction();
// Add intra-tile (or point) loops.
for (unsigned i = 0; i < width; i++) {
FuncBuilder b(topLoop);
// Loop bounds will be set later.
auto *pointLoop = b.createFor(loc, 0, 0);
auto pointLoop = b.create<AffineForOp>(loc, 0, 0);
pointLoop->createBody();
pointLoop->getBody()->getInstructions().splice(
pointLoop->getBody()->begin(), topLoop->getBlock()->getInstructions(),
topLoop);
newLoops[2 * width - 1 - i] = pointLoop;
topLoop = pointLoop;
topLoop = pointLoop->getInstruction();
if (i == 0)
innermostPointLoop = pointLoop;
}
@ -184,12 +189,13 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForInst *> band,
for (unsigned i = width; i < 2 * width; i++) {
FuncBuilder b(topLoop);
// Loop bounds will be set later.
auto *tileSpaceLoop = b.createFor(loc, 0, 0);
auto tileSpaceLoop = b.create<AffineForOp>(loc, 0, 0);
tileSpaceLoop->createBody();
tileSpaceLoop->getBody()->getInstructions().splice(
tileSpaceLoop->getBody()->begin(),
topLoop->getBlock()->getInstructions(), topLoop);
newLoops[2 * width - i - 1] = tileSpaceLoop;
topLoop = tileSpaceLoop;
topLoop = tileSpaceLoop->getInstruction();
}
// Move the loop body of the original nest to the new one.
@ -201,8 +207,8 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForInst *> band,
getIndexSet(band, &cst);
if (!cst.isHyperRectangular(0, width)) {
rootForInst->emitError("tiled code generation unimplemented for the"
"non-hyperrectangular case");
rootAffineForOp->emitError("tiled code generation unimplemented for the"
"non-hyperrectangular case");
return UtilResult::Failure;
}
@ -213,7 +219,7 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForInst *> band,
}
// Erase the old loop nest.
rootForInst->erase();
rootAffineForOp->erase();
return UtilResult::Success;
}
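// The splice-based construction above exercises the pattern this pass now
// follows: a loop built with b.create<AffineForOp> has no body until
// createBody() is called on it. A minimal sketch of that idiom, assuming an
// existing FuncBuilder 'b' and Location 'loc':
//
//   auto forOp = b.create<AffineForOp>(loc, /*lb=*/0, /*ub=*/0);
//   forOp->createBody(); // creates the entry block and the induction variable
//   // The body can now be populated, or instructions spliced into it, e.g.:
//   //   forOp->getBody()->getInstructions().splice(...);
//   auto *iv = forOp->getInductionVar();
//   (void)iv;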
@ -221,38 +227,36 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForInst *> band,
// Identify valid and profitable bands of loops to tile. This is currently just
// a temporary placeholder to test the mechanics of tiled code generation.
// Returns all maximal outermost perfect loop nests to tile.
static void getTileableBands(Function *f,
std::vector<SmallVector<ForInst *, 6>> *bands) {
static void
getTileableBands(Function *f,
std::vector<SmallVector<OpPointer<AffineForOp>, 6>> *bands) {
// Get maximal perfect nest of 'for' insts starting from root (inclusive).
auto getMaximalPerfectLoopNest = [&](ForInst *root) {
SmallVector<ForInst *, 6> band;
ForInst *currInst = root;
auto getMaximalPerfectLoopNest = [&](OpPointer<AffineForOp> root) {
SmallVector<OpPointer<AffineForOp>, 6> band;
OpPointer<AffineForOp> currInst = root;
do {
band.push_back(currInst);
} while (currInst->getBody()->getInstructions().size() == 1 &&
(currInst = dyn_cast<ForInst>(&currInst->getBody()->front())));
(currInst = cast<OperationInst>(currInst->getBody()->front())
.dyn_cast<AffineForOp>()));
bands->push_back(band);
};
for (auto &block : *f) {
for (auto &inst : block) {
auto *forInst = dyn_cast<ForInst>(&inst);
if (!forInst)
continue;
getMaximalPerfectLoopNest(forInst);
}
}
for (auto &block : *f)
for (auto &inst : block)
if (auto forOp = cast<OperationInst>(inst).dyn_cast<AffineForOp>())
getMaximalPerfectLoopNest(forOp);
}
PassResult LoopTiling::runOnFunction(Function *f) {
std::vector<SmallVector<ForInst *, 6>> bands;
std::vector<SmallVector<OpPointer<AffineForOp>, 6>> bands;
getTileableBands(f, &bands);
// Temporary tile sizes.
unsigned tileSize =
clTileSize.getNumOccurrences() > 0 ? clTileSize : kDefaultTileSize;
for (const auto &band : bands) {
for (auto &band : bands) {
SmallVector<unsigned, 6> tileSizes(band.size(), tileSize);
if (tileCodeGen(band, tileSizes)) {
return failure();

View File

@ -21,6 +21,7 @@
#include "mlir/Transforms/Passes.h"
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
@ -70,18 +71,19 @@ struct LoopUnroll : public FunctionPass {
const Optional<bool> unrollFull;
// Callback to obtain unroll factors; if this has a callable target, takes
// precedence over command-line argument or passed argument.
const std::function<unsigned(const ForInst &)> getUnrollFactor;
const std::function<unsigned(ConstOpPointer<AffineForOp>)> getUnrollFactor;
explicit LoopUnroll(
Optional<unsigned> unrollFactor = None, Optional<bool> unrollFull = None,
const std::function<unsigned(const ForInst &)> &getUnrollFactor = nullptr)
explicit LoopUnroll(Optional<unsigned> unrollFactor = None,
Optional<bool> unrollFull = None,
const std::function<unsigned(ConstOpPointer<AffineForOp>)>
&getUnrollFactor = nullptr)
: FunctionPass(&LoopUnroll::passID), unrollFactor(unrollFactor),
unrollFull(unrollFull), getUnrollFactor(getUnrollFactor) {}
PassResult runOnFunction(Function *f) override;
/// Unroll this for inst. Returns false if nothing was done.
bool runOnForInst(ForInst *forInst);
bool runOnAffineForOp(OpPointer<AffineForOp> forOp);
static const unsigned kDefaultUnrollFactor = 4;
@ -96,7 +98,7 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
class InnermostLoopGatherer : public InstWalker<InnermostLoopGatherer, bool> {
public:
// Store innermost loops as we walk.
std::vector<ForInst *> loops;
std::vector<OpPointer<AffineForOp>> loops;
// This method specialized to encode custom return logic.
using InstListType = llvm::iplist<Instruction>;
@ -111,20 +113,17 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
return hasInnerLoops;
}
bool walkForInstPostOrder(ForInst *forInst) {
bool hasInnerLoops =
walkPostOrder(forInst->getBody()->begin(), forInst->getBody()->end());
if (!hasInnerLoops)
loops.push_back(forInst);
return true;
}
bool walkOpInstPostOrder(OperationInst *opInst) {
bool hasInnerLoops = false;
for (auto &blockList : opInst->getBlockLists())
for (auto &block : blockList)
if (walkPostOrder(block.begin(), block.end()))
return true;
return false;
hasInnerLoops |= walkPostOrder(block.begin(), block.end());
if (opInst->isa<AffineForOp>()) {
if (!hasInnerLoops)
loops.push_back(opInst->cast<AffineForOp>());
return true;
}
return hasInnerLoops;
}
// FIXME: can't use base class method for this because that in turn would
@ -137,14 +136,17 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
class ShortLoopGatherer : public InstWalker<ShortLoopGatherer> {
public:
// Store short loops as we walk.
std::vector<ForInst *> loops;
std::vector<OpPointer<AffineForOp>> loops;
const unsigned minTripCount;
ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {}
void visitForInst(ForInst *forInst) {
Optional<uint64_t> tripCount = getConstantTripCount(*forInst);
void visitOperationInst(OperationInst *opInst) {
auto forOp = opInst->dyn_cast<AffineForOp>();
if (!forOp)
return;
Optional<uint64_t> tripCount = getConstantTripCount(forOp);
if (tripCount.hasValue() && tripCount.getValue() <= minTripCount)
loops.push_back(forInst);
loops.push_back(forOp);
}
};
@ -156,8 +158,8 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
// ones).
slg.walkPostOrder(f);
auto &loops = slg.loops;
for (auto *forInst : loops)
loopUnrollFull(forInst);
for (auto forOp : loops)
loopUnrollFull(forOp);
return success();
}
@ -172,8 +174,8 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
if (loops.empty())
break;
bool unrolled = false;
for (auto *forInst : loops)
unrolled |= runOnForInst(forInst);
for (auto forOp : loops)
unrolled |= runOnAffineForOp(forOp);
if (!unrolled)
// Break out if nothing was unrolled.
break;
@ -183,29 +185,30 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
/// Unrolls a 'for' inst. Returns true if the loop was unrolled, false
/// otherwise. The default unroll factor is 4.
bool LoopUnroll::runOnForInst(ForInst *forInst) {
bool LoopUnroll::runOnAffineForOp(OpPointer<AffineForOp> forOp) {
// Use the function callback if one was provided.
if (getUnrollFactor) {
return loopUnrollByFactor(forInst, getUnrollFactor(*forInst));
return loopUnrollByFactor(forOp, getUnrollFactor(forOp));
}
// Unroll by the factor passed, if any.
if (unrollFactor.hasValue())
return loopUnrollByFactor(forInst, unrollFactor.getValue());
return loopUnrollByFactor(forOp, unrollFactor.getValue());
// Unroll by the command line factor if one was specified.
if (clUnrollFactor.getNumOccurrences() > 0)
return loopUnrollByFactor(forInst, clUnrollFactor);
return loopUnrollByFactor(forOp, clUnrollFactor);
// Unroll completely if full loop unroll was specified.
if (clUnrollFull.getNumOccurrences() > 0 ||
(unrollFull.hasValue() && unrollFull.getValue()))
return loopUnrollFull(forInst);
return loopUnrollFull(forOp);
// Unroll by four otherwise.
return loopUnrollByFactor(forInst, kDefaultUnrollFactor);
return loopUnrollByFactor(forOp, kDefaultUnrollFactor);
}
FunctionPass *mlir::createLoopUnrollPass(
int unrollFactor, int unrollFull,
const std::function<unsigned(const ForInst &)> &getUnrollFactor) {
const std::function<unsigned(ConstOpPointer<AffineForOp>)>
&getUnrollFactor) {
return new LoopUnroll(
unrollFactor == -1 ? None : Optional<unsigned>(unrollFactor),
unrollFull == -1 ? None : Optional<bool>(unrollFull), getUnrollFactor);
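// A hedged usage sketch (not part of this patch) of the updated factory: the
// per-loop callback now receives a ConstOpPointer<AffineForOp>. The heuristic
// below is illustrative only.
//
//   FunctionPass *makeCustomUnrollPass() {
//     auto getUnrollFactor =
//         [](ConstOpPointer<AffineForOp> forOp) -> unsigned {
//       // A real callback would inspect the loop (trip count, body size,
//       // ...); this one simply requests a factor of 4 for every loop.
//       (void)forOp;
//       return 4;
//     };
//     return mlir::createLoopUnrollPass(/*unrollFactor=*/-1, /*unrollFull=*/-1,
//                                       getUnrollFactor);
//   }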

View File

@ -43,6 +43,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/Passes.h"
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
@ -80,7 +81,7 @@ struct LoopUnrollAndJam : public FunctionPass {
unrollJamFactor(unrollJamFactor) {}
PassResult runOnFunction(Function *f) override;
bool runOnForInst(ForInst *forInst);
bool runOnAffineForOp(OpPointer<AffineForOp> forOp);
static char passID;
};
@ -95,47 +96,51 @@ FunctionPass *mlir::createLoopUnrollAndJamPass(int unrollJamFactor) {
PassResult LoopUnrollAndJam::runOnFunction(Function *f) {
// Currently, just the outermost loop from the first loop nest is
// unroll-and-jammed by this pass. However, runOnForInst can be called on any
// for Inst.
// unroll-and-jammed by this pass. However, runOnAffineForOp can be called on
// any for Inst.
auto &entryBlock = f->front();
if (!entryBlock.empty())
if (auto *forInst = dyn_cast<ForInst>(&entryBlock.front()))
runOnForInst(forInst);
if (auto forOp =
cast<OperationInst>(entryBlock.front()).dyn_cast<AffineForOp>())
runOnAffineForOp(forOp);
return success();
}
/// Unroll and jam a 'for' inst. Default unroll jam factor is
/// kDefaultUnrollJamFactor. Return false if nothing was done.
bool LoopUnrollAndJam::runOnForInst(ForInst *forInst) {
bool LoopUnrollAndJam::runOnAffineForOp(OpPointer<AffineForOp> forOp) {
// Unroll and jam by the factor that was passed if any.
if (unrollJamFactor.hasValue())
return loopUnrollJamByFactor(forInst, unrollJamFactor.getValue());
return loopUnrollJamByFactor(forOp, unrollJamFactor.getValue());
// Otherwise, unroll jam by the command-line factor if one was specified.
if (clUnrollJamFactor.getNumOccurrences() > 0)
return loopUnrollJamByFactor(forInst, clUnrollJamFactor);
return loopUnrollJamByFactor(forOp, clUnrollJamFactor);
// Unroll and jam by four otherwise.
return loopUnrollJamByFactor(forInst, kDefaultUnrollJamFactor);
return loopUnrollJamByFactor(forOp, kDefaultUnrollJamFactor);
}
bool mlir::loopUnrollJamUpToFactor(ForInst *forInst, uint64_t unrollJamFactor) {
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst);
bool mlir::loopUnrollJamUpToFactor(OpPointer<AffineForOp> forOp,
uint64_t unrollJamFactor) {
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
if (mayBeConstantTripCount.hasValue() &&
mayBeConstantTripCount.getValue() < unrollJamFactor)
return loopUnrollJamByFactor(forInst, mayBeConstantTripCount.getValue());
return loopUnrollJamByFactor(forInst, unrollJamFactor);
return loopUnrollJamByFactor(forOp, mayBeConstantTripCount.getValue());
return loopUnrollJamByFactor(forOp, unrollJamFactor);
}
/// Unrolls and jams this loop by the specified factor.
bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) {
bool mlir::loopUnrollJamByFactor(OpPointer<AffineForOp> forOp,
uint64_t unrollJamFactor) {
// Gathers all maximal sub-blocks of instructions that do not themselves
// include a for inst (an instruction could have a descendant for inst though
// in its tree).
class JamBlockGatherer : public InstWalker<JamBlockGatherer> {
public:
using InstListType = llvm::iplist<Instruction>;
using InstWalker<JamBlockGatherer>::walk;
// Store iterators to the first and last inst of each sub-block found.
std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks;
@ -144,30 +149,30 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) {
void walk(InstListType::iterator Start, InstListType::iterator End) {
for (auto it = Start; it != End;) {
auto subBlockStart = it;
while (it != End && !isa<ForInst>(it))
while (it != End && !cast<OperationInst>(it)->isa<AffineForOp>())
++it;
if (it != subBlockStart)
subBlocks.push_back({subBlockStart, std::prev(it)});
// Process all for insts that appear next.
while (it != End && isa<ForInst>(it))
walkForInst(cast<ForInst>(it++));
while (it != End && cast<OperationInst>(it)->isa<AffineForOp>())
walk(&*it++);
}
}
};
assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1");
if (unrollJamFactor == 1 || forInst->getBody()->empty())
if (unrollJamFactor == 1 || forOp->getBody()->empty())
return false;
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst);
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
if (!mayBeConstantTripCount.hasValue() &&
getLargestDivisorOfTripCount(*forInst) % unrollJamFactor != 0)
getLargestDivisorOfTripCount(forOp) % unrollJamFactor != 0)
return false;
auto lbMap = forInst->getLowerBoundMap();
auto ubMap = forInst->getUpperBoundMap();
auto lbMap = forOp->getLowerBoundMap();
auto ubMap = forOp->getUpperBoundMap();
// Loops with max/min expressions won't be unrolled here (the output can't be
// expressed as a Function in the general case). However, the right way to
@ -178,7 +183,7 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) {
// Same operand list for lower and upper bound for now.
// TODO(bondhugula): handle bounds with different sets of operands.
if (!forInst->matchingBoundOperandList())
if (!forOp->matchingBoundOperandList())
return false;
// If the trip count is lower than the unroll jam factor, no unroll jam.
@ -187,35 +192,38 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) {
mayBeConstantTripCount.getValue() < unrollJamFactor)
return false;
auto *forInst = forOp->getInstruction();
// Gather all sub-blocks to jam upon the loop being unrolled.
JamBlockGatherer jbg;
jbg.walkForInst(forInst);
jbg.walkOpInst(forInst);
auto &subBlocks = jbg.subBlocks;
// Generate the cleanup loop if trip count isn't a multiple of
// unrollJamFactor.
if (mayBeConstantTripCount.hasValue() &&
mayBeConstantTripCount.getValue() % unrollJamFactor != 0) {
// Insert the cleanup loop right after 'forInst'.
// Insert the cleanup loop right after 'forOp'.
FuncBuilder builder(forInst->getBlock(),
std::next(Block::iterator(forInst)));
auto *cleanupForInst = cast<ForInst>(builder.clone(*forInst));
cleanupForInst->setLowerBoundMap(
getCleanupLoopLowerBound(*forInst, unrollJamFactor, &builder));
auto cleanupAffineForOp =
cast<OperationInst>(builder.clone(*forInst))->cast<AffineForOp>();
cleanupAffineForOp->setLowerBoundMap(
getCleanupLoopLowerBound(forOp, unrollJamFactor, &builder));
// The upper bound needs to be adjusted.
forInst->setUpperBoundMap(
getUnrolledLoopUpperBound(*forInst, unrollJamFactor, &builder));
forOp->setUpperBoundMap(
getUnrolledLoopUpperBound(forOp, unrollJamFactor, &builder));
// Promote the loop body up if this has turned into a single iteration loop.
promoteIfSingleIteration(cleanupForInst);
promoteIfSingleIteration(cleanupAffineForOp);
}
// Scale the step of loop being unroll-jammed by the unroll-jam factor.
int64_t step = forInst->getStep();
forInst->setStep(step * unrollJamFactor);
int64_t step = forOp->getStep();
forOp->setStep(step * unrollJamFactor);
auto *forInstIV = forInst->getInductionVar();
auto *forOpIV = forOp->getInductionVar();
for (auto &subBlock : subBlocks) {
// Builder to insert unroll-jammed bodies. Insert right at the end of
// sub-block.
@ -227,13 +235,13 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) {
// If the induction variable is used, create a remapping to the value for
// this unrolled instance.
if (!forInstIV->use_empty()) {
if (!forOpIV->use_empty()) {
// iv' = iv + i, i = 1 to unrollJamFactor-1.
auto d0 = builder.getAffineDimExpr(0);
auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {});
auto ivUnroll = builder.create<AffineApplyOp>(forInst->getLoc(),
bumpMap, forInstIV);
operandMapping.map(forInstIV, ivUnroll);
auto ivUnroll =
builder.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forOpIV);
operandMapping.map(forOpIV, ivUnroll);
}
// Clone the sub-block being unroll-jammed.
for (auto it = subBlock.first; it != std::next(subBlock.second); ++it) {
@ -243,7 +251,7 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) {
}
// Promote the loop body up if this has turned into a single iteration loop.
promoteIfSingleIteration(forInst);
promoteIfSingleIteration(forOp);
return true;
}
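To make the unroll-jam transformation above concrete, here is a minimal hand-written sketch on a hypothetical nest (the "prologue"/"compute" ops are placeholders, and the trip count is assumed to be a multiple of the factor so no cleanup loop is generated):

```mlir
// Before unroll-jam by a factor of 2.
for %i = 0 to 128 {
  "prologue"(%i) : (index) -> ()
  for %j = 0 to 8 {
    "compute"(%i, %j) : (index, index) -> ()
  }
}

// After: the step of %i is scaled by 2, and each gathered sub-block (the
// prologue and the inner-loop body) gets a second copy whose uses of %i go
// through an affine_apply that bumps the induction variable by one step.
for %i = 0 to 128 step 2 {
  "prologue"(%i) : (index) -> ()
  %i1 = affine_apply (d0) -> (d0 + 1) (%i)
  "prologue"(%i1) : (index) -> ()
  for %j = 0 to 8 {
    "compute"(%i, %j) : (index, index) -> ()
    %i2 = affine_apply (d0) -> (d0 + 1) (%i)
    "compute"(%i2, %j) : (index, index) -> ()
  }
}
```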


@ -24,6 +24,7 @@
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
@ -246,7 +247,7 @@ public:
LowerAffinePass() : FunctionPass(&passID) {}
PassResult runOnFunction(Function *function) override;
bool lowerForInst(ForInst *forInst);
bool lowerAffineFor(OpPointer<AffineForOp> forOp);
bool lowerAffineIf(AffineIfOp *ifOp);
bool lowerAffineApply(AffineApplyOp *op);
@ -295,11 +296,11 @@ static Value *buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate,
// a nested loop). Induction variable modification is appended to the body SESE
// region that always loops back to the condition block.
//
// +--------------------------------+
// | <code before the ForInst> |
// | <compute initial %iv value> |
// | br cond(%iv) |
// +--------------------------------+
// +---------------------------------+
// | <code before the AffineForOp> |
// | <compute initial %iv value> |
// | br cond(%iv) |
// +---------------------------------+
// |
// -------| |
// | v v
@ -322,11 +323,12 @@ static Value *buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate,
// v
// +--------------------------------+
// | end: |
// | <code after the ForInst> |
// | <code after the AffineForOp> |
// +--------------------------------+
//
bool LowerAffinePass::lowerForInst(ForInst *forInst) {
auto loc = forInst->getLoc();
bool LowerAffinePass::lowerAffineFor(OpPointer<AffineForOp> forOp) {
auto loc = forOp->getLoc();
auto *forInst = forOp->getInstruction();
// Start by splitting the block containing the 'for' into two parts. The part
// before will get the init code, the part after will be the end point.
@ -339,23 +341,23 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) {
conditionBlock->insertBefore(endBlock);
auto *iv = conditionBlock->addArgument(IndexType::get(forInst->getContext()));
// Create the body block, moving the body of the forInst over to it.
// Create the body block, moving the body of the forOp over to it.
auto *bodyBlock = new Block();
bodyBlock->insertBefore(endBlock);
auto *oldBody = forInst->getBody();
auto *oldBody = forOp->getBody();
bodyBlock->getInstructions().splice(bodyBlock->begin(),
oldBody->getInstructions(),
oldBody->begin(), oldBody->end());
// The code in the body of the forInst now uses 'iv' as its indvar.
forInst->getInductionVar()->replaceAllUsesWith(iv);
// The code in the body of the forOp now uses 'iv' as its indvar.
forOp->getInductionVar()->replaceAllUsesWith(iv);
// Append the induction variable stepping logic and branch back to the exit
// condition block. Construct an affine expression f : (x -> x+step) and
// apply this expression to the induction variable.
FuncBuilder builder(bodyBlock);
auto affStep = builder.getAffineConstantExpr(forInst->getStep());
auto affStep = builder.getAffineConstantExpr(forOp->getStep());
auto affDim = builder.getAffineDimExpr(0);
auto stepped = expandAffineExpr(&builder, loc, affDim + affStep, iv, {});
if (!stepped)
@ -368,18 +370,18 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) {
builder.setInsertionPointToEnd(initBlock);
// Compute loop bounds.
SmallVector<Value *, 8> operands(forInst->getLowerBoundOperands());
SmallVector<Value *, 8> operands(forOp->getLowerBoundOperands());
auto lbValues = expandAffineMap(&builder, forInst->getLoc(),
forInst->getLowerBoundMap(), operands);
forOp->getLowerBoundMap(), operands);
if (!lbValues)
return true;
Value *lowerBound =
buildMinMaxReductionSeq(loc, CmpIPredicate::SGT, *lbValues, builder);
operands.assign(forInst->getUpperBoundOperands().begin(),
forInst->getUpperBoundOperands().end());
operands.assign(forOp->getUpperBoundOperands().begin(),
forOp->getUpperBoundOperands().end());
auto ubValues = expandAffineMap(&builder, forInst->getLoc(),
forInst->getUpperBoundMap(), operands);
forOp->getUpperBoundMap(), operands);
if (!ubValues)
return true;
Value *upperBound =
@ -394,7 +396,7 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) {
endBlock, ArrayRef<Value *>());
// Ok, we're done!
forInst->erase();
forOp->erase();
return false;
}
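For illustration, a rough sketch of the control flow this lowering produces for a trivial loop. The block names, constants, and the "work" op are hypothetical, and the exact standard-dialect op spellings are approximate; the exit test is shown as a signed less-than, matching the half-open bound semantics documented for the op:

```mlir
// Input (affine form).
for %i = 0 to 42 {
  "work"(%i) : (index) -> ()
}

// Approximate lowered CFG: bounds are materialized as standard ops, the
// induction variable becomes a block argument of the condition block, and
// the body branches back after stepping the IV.
  %lb = constant 0 : index
  br ^cond(%lb : index)
^cond(%i : index):
  %ub = constant 42 : index
  %cmp = cmpi "slt", %i, %ub : index
  cond_br %cmp, ^body, ^end
^body:
  "work"(%i) : (index) -> ()
  %step = constant 1 : index
  %next = addi %i, %step : index
  br ^cond(%next : index)
^end:
  // code after the loop
```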
@ -614,28 +616,26 @@ PassResult LowerAffinePass::runOnFunction(Function *function) {
// Collect all the For instructions as well as AffineIfOps and AffineApplyOps.
// We do this as a prepass to avoid invalidating the walker with our rewrite.
function->walkInsts([&](Instruction *inst) {
if (isa<ForInst>(inst))
instsToRewrite.push_back(inst);
auto op = dyn_cast<OperationInst>(inst);
if (op && (op->isa<AffineApplyOp>() || op->isa<AffineIfOp>()))
auto op = cast<OperationInst>(inst);
if (op->isa<AffineApplyOp>() || op->isa<AffineForOp>() ||
op->isa<AffineIfOp>())
instsToRewrite.push_back(inst);
});
// Rewrite all of the ifs and fors. We walked the instructions in preorder,
// so we know that we will rewrite them in the same order.
for (auto *inst : instsToRewrite)
if (auto *forInst = dyn_cast<ForInst>(inst)) {
if (lowerForInst(forInst))
for (auto *inst : instsToRewrite) {
auto op = cast<OperationInst>(inst);
if (auto ifOp = op->dyn_cast<AffineIfOp>()) {
if (lowerAffineIf(ifOp))
return failure();
} else {
auto op = cast<OperationInst>(inst);
if (auto ifOp = op->dyn_cast<AffineIfOp>()) {
if (lowerAffineIf(ifOp))
return failure();
} else if (lowerAffineApply(op->cast<AffineApplyOp>())) {
} else if (auto forOp = op->dyn_cast<AffineForOp>()) {
if (lowerAffineFor(forOp))
return failure();
}
} else if (lowerAffineApply(op->cast<AffineApplyOp>())) {
return failure();
}
}
return success();
}


@ -75,7 +75,7 @@
/// Implementation details
/// ======================
/// The current decisions made by the super-vectorization pass guarantee that
/// use-def chains do not escape an enclosing vectorized ForInst. In other
/// use-def chains do not escape an enclosing vectorized AffineForOp. In other
/// words, this pass operates on a scoped program slice. Furthermore, since we
/// do not vectorize in the presence of conditionals for now, sliced chains are
/// guaranteed not to escape the innermost scope, which has to be either the top
@ -285,13 +285,12 @@ static Value *substitute(Value *v, VectorType hwVectorType,
///
/// The general problem this function solves is as follows:
/// Assume a vector_transfer operation at the super-vector granularity that has
/// `l` enclosing loops (ForInst). Assume the vector transfer operation operates
/// on a MemRef of rank `r`, a super-vector of rank `s` and a hardware vector of
/// rank `h`.
/// For the purpose of illustration assume l==4, r==3, s==2, h==1 and that the
/// super-vector is vector<3x32xf32> and the hardware vector is vector<8xf32>.
/// Assume the following MLIR snippet after super-vectorization has been
/// applied:
/// `l` enclosing loops (AffineForOp). Assume the vector transfer operation
/// operates on a MemRef of rank `r`, a super-vector of rank `s` and a hardware
/// vector of rank `h`. For the purpose of illustration assume l==4, r==3, s==2,
/// h==1 and that the super-vector is vector<3x32xf32> and the hardware vector
/// is vector<8xf32>. Assume the following MLIR snippet after
/// super-vectorization has been applied:
///
/// ```mlir
/// for %i0 = 0 to %M {
@ -351,7 +350,7 @@ reindexAffineIndices(FuncBuilder *b, VectorType hwVectorType,
SmallVector<AffineExpr, 8> affineExprs;
// TODO(ntv): support a concrete map and composition.
unsigned i = 0;
// The first numMemRefIndices correspond to ForInst that have not been
// The first numMemRefIndices correspond to AffineForOp that have not been
// vectorized, the transformation is the identity on those.
for (i = 0; i < numMemRefIndices; ++i) {
auto d_i = b->getAffineDimExpr(i);
@ -554,9 +553,6 @@ static bool instantiateMaterialization(Instruction *inst,
MaterializationState *state) {
LLVM_DEBUG(dbgs() << "\ninstantiate: " << *inst);
if (isa<ForInst>(inst))
return inst->emitError("NYI path ForInst");
// Create a builder here for unroll-and-jam effects.
FuncBuilder b(inst);
auto *opInst = cast<OperationInst>(inst);


@ -21,11 +21,11 @@
#include "mlir/Transforms/Passes.h"
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/LoopUtils.h"
@ -38,15 +38,12 @@ using namespace mlir;
namespace {
struct PipelineDataTransfer : public FunctionPass,
InstWalker<PipelineDataTransfer> {
struct PipelineDataTransfer : public FunctionPass {
PipelineDataTransfer() : FunctionPass(&PipelineDataTransfer::passID) {}
PassResult runOnFunction(Function *f) override;
PassResult runOnForInst(ForInst *forInst);
PassResult runOnAffineForOp(OpPointer<AffineForOp> forOp);
// Collect all 'for' instructions.
void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); }
std::vector<ForInst *> forInsts;
std::vector<OpPointer<AffineForOp>> forOps;
static char passID;
};
@ -79,8 +76,8 @@ static unsigned getTagMemRefPos(const OperationInst &dmaInst) {
/// of the old memref by the new one while indexing the newly added dimension by
/// the loop IV of the specified 'for' instruction modulo 2. Returns false if
/// such a replacement cannot be performed.
static bool doubleBuffer(Value *oldMemRef, ForInst *forInst) {
auto *forBody = forInst->getBody();
static bool doubleBuffer(Value *oldMemRef, OpPointer<AffineForOp> forOp) {
auto *forBody = forOp->getBody();
FuncBuilder bInner(forBody, forBody->begin());
bInner.setInsertionPoint(forBody, forBody->begin());
@ -101,6 +98,7 @@ static bool doubleBuffer(Value *oldMemRef, ForInst *forInst) {
auto newMemRefType = doubleShape(oldMemRefType);
// Put together alloc operands for the dynamic dimensions of the memref.
auto *forInst = forOp->getInstruction();
FuncBuilder bOuter(forInst);
SmallVector<Value *, 4> allocOperands;
unsigned dynamicDimCount = 0;
@ -118,16 +116,16 @@ static bool doubleBuffer(Value *oldMemRef, ForInst *forInst) {
// Create 'iv mod 2' value to index the leading dimension.
auto d0 = bInner.getAffineDimExpr(0);
int64_t step = forInst->getStep();
int64_t step = forOp->getStep();
auto modTwoMap = bInner.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0,
{d0.floorDiv(step) % 2}, {});
auto ivModTwoOp = bInner.create<AffineApplyOp>(forInst->getLoc(), modTwoMap,
forInst->getInductionVar());
auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp->getLoc(), modTwoMap,
forOp->getInductionVar());
// replaceAllMemRefUsesWith will always succeed unless the forInst body has
// replaceAllMemRefUsesWith will always succeed unless the forOp body has
// non-dereferencing uses of the memref.
if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, {ivModTwoOp}, AffineMap(),
{}, &*forInst->getBody()->begin())) {
{}, &*forOp->getBody()->begin())) {
LLVM_DEBUG(llvm::dbgs()
<< "memref replacement for double buffering failed\n";);
ivModTwoOp->getInstruction()->erase();
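A hedged sketch of what the doubleBuffer helper above does (buffer shapes and op names are hypothetical; the DMA and compute operations are elided as comments): a leading dimension of size two is added and every access is indexed by the induction variable modulo two, so consecutive iterations use disjoint halves of the buffer:

```mlir
// Before: a single-buffered fast-memory buffer, live only inside the loop.
%buf = alloc() : memref<256xf32, 1>
for %i = 0 to 16 {
  // dma_start / dma_wait / compute using %buf[...]
}

// After double buffering: one extra leading dimension, indexed by
// (%i floordiv step) mod 2 (the step is 1 here).
%buf2 = alloc() : memref<2x256xf32, 1>
for %i = 0 to 16 {
  %idx = affine_apply (d0) -> (d0 mod 2) (%i)
  // dma_start / dma_wait / compute using %buf2[%idx, ...]
}
```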
@ -143,11 +141,14 @@ PassResult PipelineDataTransfer::runOnFunction(Function *f) {
// invalid (erased) when the outer loop is pipelined (the pipelined one gets
// deleted and replaced by a prologue, a new steady-state loop and an
// epilogue).
forInsts.clear();
walkPostOrder(f);
forOps.clear();
f->walkOpsPostOrder([&](OperationInst *opInst) {
if (auto forOp = opInst->dyn_cast<AffineForOp>())
forOps.push_back(forOp);
});
bool ret = false;
for (auto *forInst : forInsts) {
ret = ret | runOnForInst(forInst);
for (auto forOp : forOps) {
ret = ret | runOnAffineForOp(forOp);
}
return ret ? failure() : success();
}
@ -178,13 +179,13 @@ static bool checkTagMatch(OpPointer<DmaStartOp> startOp,
// Identify matching DMA start/finish instructions to overlap computation with.
static void findMatchingStartFinishInsts(
ForInst *forInst,
OpPointer<AffineForOp> forOp,
SmallVectorImpl<std::pair<OperationInst *, OperationInst *>>
&startWaitPairs) {
// Collect outgoing DMA instructions - needed to check for dependences below.
SmallVector<OpPointer<DmaStartOp>, 4> outgoingDmaOps;
for (auto &inst : *forInst->getBody()) {
for (auto &inst : *forOp->getBody()) {
auto *opInst = dyn_cast<OperationInst>(&inst);
if (!opInst)
continue;
@ -195,7 +196,7 @@ static void findMatchingStartFinishInsts(
}
SmallVector<OperationInst *, 4> dmaStartInsts, dmaFinishInsts;
for (auto &inst : *forInst->getBody()) {
for (auto &inst : *forOp->getBody()) {
auto *opInst = dyn_cast<OperationInst>(&inst);
if (!opInst)
continue;
@ -227,7 +228,7 @@ static void findMatchingStartFinishInsts(
auto *memref = dmaStartOp->getOperand(dmaStartOp->getFasterMemPos());
bool escapingUses = false;
for (const auto &use : memref->getUses()) {
if (!forInst->getBody()->findAncestorInstInBlock(*use.getOwner())) {
if (!forOp->getBody()->findAncestorInstInBlock(*use.getOwner())) {
LLVM_DEBUG(llvm::dbgs()
<< "can't pipeline: buffer is live out of loop\n";);
escapingUses = true;
@ -251,17 +252,18 @@ static void findMatchingStartFinishInsts(
}
/// Overlap DMA transfers with computation in this loop. If successful,
/// 'forInst' is deleted, and a prologue, a new pipelined loop, and epilogue are
/// 'forOp' is deleted, and a prologue, a new pipelined loop, and epilogue are
/// inserted right before where it was.
PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) {
auto mayBeConstTripCount = getConstantTripCount(*forInst);
PassResult
PipelineDataTransfer::runOnAffineForOp(OpPointer<AffineForOp> forOp) {
auto mayBeConstTripCount = getConstantTripCount(forOp);
if (!mayBeConstTripCount.hasValue()) {
LLVM_DEBUG(llvm::dbgs() << "unknown trip count loop\n");
return success();
}
SmallVector<std::pair<OperationInst *, OperationInst *>, 4> startWaitPairs;
findMatchingStartFinishInsts(forInst, startWaitPairs);
findMatchingStartFinishInsts(forOp, startWaitPairs);
if (startWaitPairs.empty()) {
LLVM_DEBUG(llvm::dbgs() << "No dma start/finish pairs\n";);
@ -280,7 +282,7 @@ PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) {
auto *dmaStartInst = pair.first;
Value *oldMemRef = dmaStartInst->getOperand(
dmaStartInst->cast<DmaStartOp>()->getFasterMemPos());
if (!doubleBuffer(oldMemRef, forInst)) {
if (!doubleBuffer(oldMemRef, forOp)) {
// Normally, double buffering should not fail because we already checked
// that there are no uses outside.
LLVM_DEBUG(llvm::dbgs() << "double buffering failed for: \n";);
@ -302,7 +304,7 @@ PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) {
auto *dmaFinishInst = pair.second;
Value *oldTagMemRef =
dmaFinishInst->getOperand(getTagMemRefPos(*dmaFinishInst));
if (!doubleBuffer(oldTagMemRef, forInst)) {
if (!doubleBuffer(oldTagMemRef, forOp)) {
LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
return success();
}
@ -315,7 +317,7 @@ PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) {
// Double buffering would have invalidated all the old DMA start/wait insts.
startWaitPairs.clear();
findMatchingStartFinishInsts(forInst, startWaitPairs);
findMatchingStartFinishInsts(forOp, startWaitPairs);
// Store shift for instruction for later lookup for AffineApplyOp's.
DenseMap<const Instruction *, unsigned> instShiftMap;
@ -342,16 +344,16 @@ PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) {
}
}
// Everything else (including compute ops and dma finish) are shifted by one.
for (const auto &inst : *forInst->getBody()) {
for (const auto &inst : *forOp->getBody()) {
if (instShiftMap.find(&inst) == instShiftMap.end()) {
instShiftMap[&inst] = 1;
}
}
// Get shifts stored in map.
std::vector<uint64_t> shifts(forInst->getBody()->getInstructions().size());
std::vector<uint64_t> shifts(forOp->getBody()->getInstructions().size());
unsigned s = 0;
for (auto &inst : *forInst->getBody()) {
for (auto &inst : *forOp->getBody()) {
assert(instShiftMap.find(&inst) != instShiftMap.end());
shifts[s++] = instShiftMap[&inst];
LLVM_DEBUG(
@ -363,13 +365,13 @@ PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) {
});
}
if (!isInstwiseShiftValid(*forInst, shifts)) {
if (!isInstwiseShiftValid(forOp, shifts)) {
// Violates dependences.
LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";);
return success();
}
if (instBodySkew(forInst, shifts)) {
if (instBodySkew(forOp, shifts)) {
LLVM_DEBUG(llvm::dbgs() << "inst body skewing failed - unexpected\n";);
return success();
}


@ -22,6 +22,7 @@
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Instructions.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/Pass.h"
#include "mlir/Transforms/Passes.h"


@ -21,6 +21,7 @@
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
@ -39,22 +40,22 @@ using namespace mlir;
/// Returns the upper bound of an unrolled loop with lower bound 'lb' and with
/// the specified trip count, stride, and unroll factor. Returns nullptr when
/// the trip count can't be expressed as an affine expression.
AffineMap mlir::getUnrolledLoopUpperBound(const ForInst &forInst,
AffineMap mlir::getUnrolledLoopUpperBound(ConstOpPointer<AffineForOp> forOp,
unsigned unrollFactor,
FuncBuilder *builder) {
auto lbMap = forInst.getLowerBoundMap();
auto lbMap = forOp->getLowerBoundMap();
// Single result lower bound map only.
if (lbMap.getNumResults() != 1)
return AffineMap();
// Sometimes, the trip count cannot be expressed as an affine expression.
auto tripCount = getTripCountExpr(forInst);
auto tripCount = getTripCountExpr(forOp);
if (!tripCount)
return AffineMap();
AffineExpr lb(lbMap.getResult(0));
unsigned step = forInst.getStep();
unsigned step = forOp->getStep();
auto newUb = lb + (tripCount - tripCount % unrollFactor - 1) * step;
return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(),
@ -65,50 +66,51 @@ AffineMap mlir::getUnrolledLoopUpperBound(const ForInst &forInst,
/// bound 'lb' and with the specified trip count, stride, and unroll factor.
/// Returns an AffineMap with nullptr storage (that evaluates to false)
/// when the trip count can't be expressed as an affine expression.
AffineMap mlir::getCleanupLoopLowerBound(const ForInst &forInst,
AffineMap mlir::getCleanupLoopLowerBound(ConstOpPointer<AffineForOp> forOp,
unsigned unrollFactor,
FuncBuilder *builder) {
auto lbMap = forInst.getLowerBoundMap();
auto lbMap = forOp->getLowerBoundMap();
// Single result lower bound map only.
if (lbMap.getNumResults() != 1)
return AffineMap();
// Sometimes the trip count cannot be expressed as an affine expression.
AffineExpr tripCount(getTripCountExpr(forInst));
AffineExpr tripCount(getTripCountExpr(forOp));
if (!tripCount)
return AffineMap();
AffineExpr lb(lbMap.getResult(0));
unsigned step = forInst.getStep();
unsigned step = forOp->getStep();
auto newLb = lb + (tripCount - tripCount % unrollFactor) * step;
return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(),
{newLb}, {});
}
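As a worked check of the two bound maps above, with hypothetical numbers: for a single-result lower bound map (d0) -> (d0), step 1, a constant trip count of 10, and unrollFactor 4, tripCount - tripCount mod unrollFactor is 8, so getUnrolledLoopUpperBound yields (d0) -> (d0 + 7) and getCleanupLoopLowerBound yields (d0) -> (d0 + 8); the cleanup loop then covers the remaining 10 mod 4 = 2 iterations.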
/// Promotes the loop body of a forInst to its containing block if the forInst
/// Promotes the loop body of a forOp to its containing block if the forOp
/// was known to have a single iteration. Returns false otherwise.
// TODO(bondhugula): extend this for arbitrary affine bounds.
bool mlir::promoteIfSingleIteration(ForInst *forInst) {
Optional<uint64_t> tripCount = getConstantTripCount(*forInst);
bool mlir::promoteIfSingleIteration(OpPointer<AffineForOp> forOp) {
Optional<uint64_t> tripCount = getConstantTripCount(forOp);
if (!tripCount.hasValue() || tripCount.getValue() != 1)
return false;
// TODO(mlir-team): there is no builder for a max.
if (forInst->getLowerBoundMap().getNumResults() != 1)
if (forOp->getLowerBoundMap().getNumResults() != 1)
return false;
// Replaces all IV uses to its single iteration value.
auto *iv = forInst->getInductionVar();
auto *iv = forOp->getInductionVar();
OperationInst *forInst = forOp->getInstruction();
if (!iv->use_empty()) {
if (forInst->hasConstantLowerBound()) {
if (forOp->hasConstantLowerBound()) {
auto *mlFunc = forInst->getFunction();
FuncBuilder topBuilder(mlFunc);
auto constOp = topBuilder.create<ConstantIndexOp>(
forInst->getLoc(), forInst->getConstantLowerBound());
forOp->getLoc(), forOp->getConstantLowerBound());
iv->replaceAllUsesWith(constOp);
} else {
const AffineBound lb = forInst->getLowerBound();
const AffineBound lb = forOp->getLowerBound();
SmallVector<Value *, 4> lbOperands(lb.operand_begin(), lb.operand_end());
FuncBuilder builder(forInst->getBlock(), Block::iterator(forInst));
if (lb.getMap() == builder.getDimIdentityMap()) {
@ -124,8 +126,8 @@ bool mlir::promoteIfSingleIteration(ForInst *forInst) {
// Move the loop body instructions to the loop's containing block.
auto *block = forInst->getBlock();
block->getInstructions().splice(Block::iterator(forInst),
forInst->getBody()->getInstructions());
forInst->erase();
forOp->getBody()->getInstructions());
forOp->erase();
return true;
}
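A minimal sketch of the promotion on a hypothetical loop (the "use" op is a placeholder); under the half-open bounds documented for the op, this loop runs exactly once with %i equal to its constant lower bound:

```mlir
// Before.
for %i = 4 to 5 {
  "use"(%i) : (index) -> ()
}

// After promoteIfSingleIteration: the induction variable is replaced by a
// constant created at the top of the function, and the body is spliced into
// the enclosing block.
%c4 = constant 4 : index
"use"(%c4) : (index) -> ()
```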
@ -133,13 +135,10 @@ bool mlir::promoteIfSingleIteration(ForInst *forInst) {
/// their body into the containing Block.
void mlir::promoteSingleIterationLoops(Function *f) {
// Gathers all innermost loops through a post order pruned walk.
class LoopBodyPromoter : public InstWalker<LoopBodyPromoter> {
public:
void visitForInst(ForInst *forInst) { promoteIfSingleIteration(forInst); }
};
LoopBodyPromoter fsw;
fsw.walkPostOrder(f);
f->walkOpsPostOrder([](OperationInst *inst) {
if (auto forOp = inst->dyn_cast<AffineForOp>())
promoteIfSingleIteration(forOp);
});
}
/// Generates a 'for' inst with the specified lower and upper bounds while
@ -149,19 +148,22 @@ void mlir::promoteSingleIterationLoops(Function *f) {
/// the pair specifies the shift applied to that group of instructions; note
/// that the shift is multiplied by the loop step before being applied. Returns
/// nullptr if the generated loop simplifies to a single iteration one.
static ForInst *
static OpPointer<AffineForOp>
generateLoop(AffineMap lbMap, AffineMap ubMap,
const std::vector<std::pair<uint64_t, ArrayRef<Instruction *>>>
&instGroupQueue,
unsigned offset, ForInst *srcForInst, FuncBuilder *b) {
unsigned offset, OpPointer<AffineForOp> srcForInst,
FuncBuilder *b) {
SmallVector<Value *, 4> lbOperands(srcForInst->getLowerBoundOperands());
SmallVector<Value *, 4> ubOperands(srcForInst->getUpperBoundOperands());
assert(lbMap.getNumInputs() == lbOperands.size());
assert(ubMap.getNumInputs() == ubOperands.size());
auto *loopChunk = b->createFor(srcForInst->getLoc(), lbOperands, lbMap,
ubOperands, ubMap, srcForInst->getStep());
auto loopChunk =
b->create<AffineForOp>(srcForInst->getLoc(), lbOperands, lbMap,
ubOperands, ubMap, srcForInst->getStep());
loopChunk->createBody();
auto *loopChunkIV = loopChunk->getInductionVar();
auto *srcIV = srcForInst->getInductionVar();
@ -176,7 +178,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap,
// Generate the remapping if the shift is not zero: remappedIV = newIV -
// shift.
if (!srcIV->use_empty() && shift != 0) {
auto b = FuncBuilder::getForInstBodyBuilder(loopChunk);
FuncBuilder b(loopChunk->getBody());
auto ivRemap = b.create<AffineApplyOp>(
srcForInst->getLoc(),
b.getSingleDimShiftAffineMap(
@ -191,7 +193,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap,
}
}
if (promoteIfSingleIteration(loopChunk))
return nullptr;
return OpPointer<AffineForOp>();
return loopChunk;
}
@ -210,28 +212,29 @@ generateLoop(AffineMap lbMap, AffineMap ubMap,
// asserts preservation of SSA dominance. A check for that, as well as for
// memory-based dependence preservation, rests with the users of this
// method.
UtilResult mlir::instBodySkew(ForInst *forInst, ArrayRef<uint64_t> shifts,
UtilResult mlir::instBodySkew(OpPointer<AffineForOp> forOp,
ArrayRef<uint64_t> shifts,
bool unrollPrologueEpilogue) {
if (forInst->getBody()->empty())
if (forOp->getBody()->empty())
return UtilResult::Success;
// If the trip counts aren't constant, we would need versioning and
// conditional guards (or context information to prevent such versioning). The
// better way to pipeline for such loops is to first tile them and extract
// constant trip count "full tiles" before applying this.
auto mayBeConstTripCount = getConstantTripCount(*forInst);
auto mayBeConstTripCount = getConstantTripCount(forOp);
if (!mayBeConstTripCount.hasValue()) {
LLVM_DEBUG(llvm::dbgs() << "non-constant trip count loop\n";);
return UtilResult::Success;
}
uint64_t tripCount = mayBeConstTripCount.getValue();
assert(isInstwiseShiftValid(*forInst, shifts) &&
assert(isInstwiseShiftValid(forOp, shifts) &&
"shifts will lead to an invalid transformation\n");
int64_t step = forInst->getStep();
int64_t step = forOp->getStep();
unsigned numChildInsts = forInst->getBody()->getInstructions().size();
unsigned numChildInsts = forOp->getBody()->getInstructions().size();
// Do a linear time (counting) sort for the shifts.
uint64_t maxShift = 0;
@ -249,7 +252,7 @@ UtilResult mlir::instBodySkew(ForInst *forInst, ArrayRef<uint64_t> shifts,
// body of the 'for' inst.
std::vector<std::vector<Instruction *>> sortedInstGroups(maxShift + 1);
unsigned pos = 0;
for (auto &inst : *forInst->getBody()) {
for (auto &inst : *forOp->getBody()) {
auto shift = shifts[pos++];
sortedInstGroups[shift].push_back(&inst);
}
@ -259,17 +262,17 @@ UtilResult mlir::instBodySkew(ForInst *forInst, ArrayRef<uint64_t> shifts,
// Nevertheless, if 'unrollPrologueEpilogue' is set, we will treat the first
// loop generated as the prologue and the last as epilogue and unroll these
// fully.
ForInst *prologue = nullptr;
ForInst *epilogue = nullptr;
OpPointer<AffineForOp> prologue;
OpPointer<AffineForOp> epilogue;
// Do a sweep over the sorted shifts while storing open groups in a
// vector, and generating loop portions as necessary during the sweep. A block
// of instructions is paired with its shift.
std::vector<std::pair<uint64_t, ArrayRef<Instruction *>>> instGroupQueue;
auto origLbMap = forInst->getLowerBoundMap();
auto origLbMap = forOp->getLowerBoundMap();
uint64_t lbShift = 0;
FuncBuilder b(forInst);
FuncBuilder b(forOp->getInstruction());
for (uint64_t d = 0, e = sortedInstGroups.size(); d < e; ++d) {
// If nothing is shifted by d, continue.
if (sortedInstGroups[d].empty())
@ -280,19 +283,19 @@ UtilResult mlir::instBodySkew(ForInst *forInst, ArrayRef<uint64_t> shifts,
// The interval for which the loop needs to be generated here is:
// [lbShift, min(lbShift + tripCount, d)) and the body of the
// loop needs to have all instructions in instQueue in that order.
ForInst *res;
OpPointer<AffineForOp> res;
if (lbShift + tripCount * step < d * step) {
res = generateLoop(
b.getShiftedAffineMap(origLbMap, lbShift),
b.getShiftedAffineMap(origLbMap, lbShift + tripCount * step),
instGroupQueue, 0, forInst, &b);
instGroupQueue, 0, forOp, &b);
// Entire loop for the queued inst groups generated, empty it.
instGroupQueue.clear();
lbShift += tripCount * step;
} else {
res = generateLoop(b.getShiftedAffineMap(origLbMap, lbShift),
b.getShiftedAffineMap(origLbMap, d), instGroupQueue,
0, forInst, &b);
0, forOp, &b);
lbShift = d * step;
}
if (!prologue && res)
@ -312,60 +315,63 @@ UtilResult mlir::instBodySkew(ForInst *forInst, ArrayRef<uint64_t> shifts,
uint64_t ubShift = (instGroupQueue[i].first + tripCount) * step;
epilogue = generateLoop(b.getShiftedAffineMap(origLbMap, lbShift),
b.getShiftedAffineMap(origLbMap, ubShift),
instGroupQueue, i, forInst, &b);
instGroupQueue, i, forOp, &b);
lbShift = ubShift;
if (!prologue)
prologue = epilogue;
}
// Erase the original for inst.
forInst->erase();
forOp->erase();
if (unrollPrologueEpilogue && prologue)
loopUnrollFull(prologue);
if (unrollPrologueEpilogue && !epilogue && epilogue != prologue)
if (unrollPrologueEpilogue && !epilogue &&
epilogue->getInstruction() != prologue->getInstruction())
loopUnrollFull(epilogue);
return UtilResult::Success;
}
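A conceptual sketch of the skew for shifts = [0, 1] on a hypothetical two-instruction body with constant trip count 8 and step 1. This is not the exact IR the pass emits (single-iteration prologue/epilogue loops are promoted via promoteIfSingleIteration and bound maps are folded), but it shows the intent:

```mlir
// Input.
for %i = 0 to 8 {
  "produce"(%i) : (index) -> ()
  "consume"(%i) : (index) -> ()
}

// Conceptual result: a prologue runs the unshifted group, a steady-state
// loop runs both groups with the shifted group's IV remapped by -1, and an
// epilogue drains the shifted group.
%c0 = constant 0 : index
"produce"(%c0) : (index) -> ()
for %i = 1 to 8 {
  "produce"(%i) : (index) -> ()
  %is = affine_apply (d0) -> (d0 - 1) (%i)
  "consume"(%is) : (index) -> ()
}
%c7 = constant 7 : index
"consume"(%c7) : (index) -> ()
```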
/// Unrolls this loop completely.
bool mlir::loopUnrollFull(ForInst *forInst) {
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst);
bool mlir::loopUnrollFull(OpPointer<AffineForOp> forOp) {
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
if (mayBeConstantTripCount.hasValue()) {
uint64_t tripCount = mayBeConstantTripCount.getValue();
if (tripCount == 1) {
return promoteIfSingleIteration(forInst);
return promoteIfSingleIteration(forOp);
}
return loopUnrollByFactor(forInst, tripCount);
return loopUnrollByFactor(forOp, tripCount);
}
return false;
}
/// Unrolls this loop by the specified factor or by the trip count (if
/// constant), whichever is lower.
bool mlir::loopUnrollUpToFactor(ForInst *forInst, uint64_t unrollFactor) {
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst);
bool mlir::loopUnrollUpToFactor(OpPointer<AffineForOp> forOp,
uint64_t unrollFactor) {
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
if (mayBeConstantTripCount.hasValue() &&
mayBeConstantTripCount.getValue() < unrollFactor)
return loopUnrollByFactor(forInst, mayBeConstantTripCount.getValue());
return loopUnrollByFactor(forInst, unrollFactor);
return loopUnrollByFactor(forOp, mayBeConstantTripCount.getValue());
return loopUnrollByFactor(forOp, unrollFactor);
}
/// Unrolls this loop by the specified factor. Returns true if the loop
/// is successfully unrolled.
bool mlir::loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor) {
bool mlir::loopUnrollByFactor(OpPointer<AffineForOp> forOp,
uint64_t unrollFactor) {
assert(unrollFactor >= 1 && "unroll factor should be >= 1");
if (unrollFactor == 1)
return promoteIfSingleIteration(forInst);
return promoteIfSingleIteration(forOp);
if (forInst->getBody()->empty())
if (forOp->getBody()->empty())
return false;
auto lbMap = forInst->getLowerBoundMap();
auto ubMap = forInst->getUpperBoundMap();
auto lbMap = forOp->getLowerBoundMap();
auto ubMap = forOp->getUpperBoundMap();
// Loops with max/min expressions won't be unrolled here (the output can't be
// expressed as a Function in the general case). However, the right way to
@ -376,10 +382,10 @@ bool mlir::loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor) {
// Same operand list for lower and upper bound for now.
// TODO(bondhugula): handle bounds with different operand lists.
if (!forInst->matchingBoundOperandList())
if (!forOp->matchingBoundOperandList())
return false;
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst);
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
// If the trip count is lower than the unroll factor, no unrolled body.
// TODO(bondhugula): option to specify cleanup loop unrolling.
@ -388,10 +394,12 @@ bool mlir::loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor) {
return false;
// Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
if (getLargestDivisorOfTripCount(*forInst) % unrollFactor != 0) {
OperationInst *forInst = forOp->getInstruction();
if (getLargestDivisorOfTripCount(forOp) % unrollFactor != 0) {
FuncBuilder builder(forInst->getBlock(), ++Block::iterator(forInst));
auto *cleanupForInst = cast<ForInst>(builder.clone(*forInst));
auto clLbMap = getCleanupLoopLowerBound(*forInst, unrollFactor, &builder);
auto cleanupForInst =
cast<OperationInst>(builder.clone(*forInst))->cast<AffineForOp>();
auto clLbMap = getCleanupLoopLowerBound(forOp, unrollFactor, &builder);
assert(clLbMap &&
"cleanup loop lower bound map for single result bound maps can "
"always be determined");
@ -401,50 +409,50 @@ bool mlir::loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor) {
// Adjust upper bound.
auto unrolledUbMap =
getUnrolledLoopUpperBound(*forInst, unrollFactor, &builder);
getUnrolledLoopUpperBound(forOp, unrollFactor, &builder);
assert(unrolledUbMap &&
"upper bound map can alwayys be determined for an unrolled loop "
"with single result bounds");
forInst->setUpperBoundMap(unrolledUbMap);
forOp->setUpperBoundMap(unrolledUbMap);
}
// Scale the step of loop being unrolled by unroll factor.
int64_t step = forInst->getStep();
forInst->setStep(step * unrollFactor);
int64_t step = forOp->getStep();
forOp->setStep(step * unrollFactor);
// Builder to insert unrolled bodies right after the last instruction in the
// body of 'forInst'.
FuncBuilder builder(forInst->getBody(), forInst->getBody()->end());
// body of 'forOp'.
FuncBuilder builder(forOp->getBody(), forOp->getBody()->end());
// Keep a pointer to the last instruction in the original block so that we
// know what to clone (since we are doing this in-place).
Block::iterator srcBlockEnd = std::prev(forInst->getBody()->end());
Block::iterator srcBlockEnd = std::prev(forOp->getBody()->end());
// Unroll the contents of 'forInst' (append unrollFactor-1 additional copies).
auto *forInstIV = forInst->getInductionVar();
// Unroll the contents of 'forOp' (append unrollFactor-1 additional copies).
auto *forOpIV = forOp->getInductionVar();
for (unsigned i = 1; i < unrollFactor; i++) {
BlockAndValueMapping operandMap;
// If the induction variable is used, create a remapping to the value for
// this unrolled instance.
if (!forInstIV->use_empty()) {
if (!forOpIV->use_empty()) {
// iv' = iv + 1/2/3...unrollFactor-1;
auto d0 = builder.getAffineDimExpr(0);
auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {});
auto ivUnroll =
builder.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forInstIV);
operandMap.map(forInstIV, ivUnroll);
builder.create<AffineApplyOp>(forOp->getLoc(), bumpMap, forOpIV);
operandMap.map(forOpIV, ivUnroll);
}
// Clone the original body of 'forInst'.
for (auto it = forInst->getBody()->begin(); it != std::next(srcBlockEnd);
// Clone the original body of 'forOp'.
for (auto it = forOp->getBody()->begin(); it != std::next(srcBlockEnd);
it++) {
builder.clone(*it, operandMap);
}
}
// Promote the loop body up if this has turned into a single iteration loop.
promoteIfSingleIteration(forInst);
promoteIfSingleIteration(forOp);
return true;
}


@ -22,6 +22,7 @@
#include "mlir/Transforms/Utils.h"
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/Dominance.h"
@ -278,8 +279,8 @@ void mlir::createAffineComputationSlice(
/// Folds the specified (lower or upper) bound to a constant if possible
/// considering its operands. Returns false if the folding happens for any of
/// the bounds, true otherwise.
bool mlir::constantFoldBounds(ForInst *forInst) {
auto foldLowerOrUpperBound = [forInst](bool lower) {
bool mlir::constantFoldBounds(OpPointer<AffineForOp> forInst) {
auto foldLowerOrUpperBound = [&forInst](bool lower) {
// Check if the bound is already a constant.
if (lower && forInst->hasConstantLowerBound())
return true;
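To illustrate the folding on a hypothetical case: if a bound is an identity map applied to operands that are themselves constant ops, the bound can be rewritten as a literal constant bound:

```mlir
// Before: bounds given by symbol-identity maps applied to constants.
%c0 = constant 0 : index
%c64 = constant 64 : index
for %i = ()[s0] -> (s0) ()[%c0] to ()[s0] -> (s0) ()[%c64] {
  "work"(%i) : (index) -> ()
}

// After constantFoldBounds: the operands are folded away and the bounds
// become plain integer constants.
for %i = 0 to 64 {
  "work"(%i) : (index) -> ()
}
```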


@ -20,6 +20,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/NestedMatcher.h"
#include "mlir/Analysis/VectorAnalysis.h"
@ -252,9 +253,9 @@ using namespace mlir;
/// ==========
/// The algorithm proceeds in a few steps:
/// 1. defining super-vectorization patterns and matching them on the tree of
/// ForInst. A super-vectorization pattern is defined as a recursive data
/// structure that matches and captures nested, imperfectly-nested loops
/// that have a. conformable loop annotations attached (e.g. parallel,
/// AffineForOp. A super-vectorization pattern is defined as a recursive
/// data structure that matches and captures nested, imperfectly-nested
/// loops that have a. conformable loop annotations attached (e.g. parallel,
/// reduction, vectorizable, ...) as well as b. all contiguous load/store
/// operations along a specified minor dimension (not necessarily the
/// fastest varying) ;
@ -279,11 +280,11 @@ using namespace mlir;
/// it by its vector form. Otherwise, if the scalar value is a constant,
/// it is vectorized into a splat. In all other cases, vectorization for
/// the pattern currently fails.
/// e. if everything under the root ForInst in the current pattern vectorizes
/// properly, we commit that loop to the IR. Otherwise we discard it and
/// restore a previously cloned version of the loop. Thanks to the
/// recursive scoping nature of matchers and captured patterns, this is
/// transparently achieved by a simple RAII implementation.
/// e. if everything under the root AffineForOp in the current pattern
/// vectorizes properly, we commit that loop to the IR. Otherwise we
/// discard it and restore a previously cloned version of the loop. Thanks
/// to the recursive scoping nature of matchers and captured patterns,
/// this is transparently achieved by a simple RAII implementation.
/// f. vectorization is applied on the next pattern in the list. Because
/// pattern interference avoidance is not yet implemented and that we do
/// not support further vectorizing an already vector load we need to
@ -667,12 +668,13 @@ namespace {
struct VectorizationStrategy {
SmallVector<int64_t, 8> vectorSizes;
DenseMap<ForInst *, unsigned> loopToVectorDim;
DenseMap<Instruction *, unsigned> loopToVectorDim;
};
} // end anonymous namespace
static void vectorizeLoopIfProfitable(ForInst *loop, unsigned depthInPattern,
static void vectorizeLoopIfProfitable(Instruction *loop,
unsigned depthInPattern,
unsigned patternDepth,
VectorizationStrategy *strategy) {
assert(patternDepth > depthInPattern &&
@ -704,13 +706,13 @@ static bool analyzeProfitability(ArrayRef<NestedMatch> matches,
unsigned depthInPattern, unsigned patternDepth,
VectorizationStrategy *strategy) {
for (auto m : matches) {
auto *loop = cast<ForInst>(m.getMatchedInstruction());
bool fail = analyzeProfitability(m.getMatchedChildren(), depthInPattern + 1,
patternDepth, strategy);
if (fail) {
return fail;
}
vectorizeLoopIfProfitable(loop, depthInPattern, patternDepth, strategy);
vectorizeLoopIfProfitable(m.getMatchedInstruction(), depthInPattern,
patternDepth, strategy);
}
return false;
}
@ -855,8 +857,8 @@ static bool vectorizeRootOrTerminal(Value *iv, LoadOrStoreOpPointer memoryOp,
/// Coarsens the loops bounds and transforms all remaining load and store
/// operations into the appropriate vector_transfer.
static bool vectorizeForInst(ForInst *loop, int64_t step,
VectorizationState *state) {
static bool vectorizeAffineForOp(AffineForOp *loop, int64_t step,
VectorizationState *state) {
using namespace functional;
loop->setStep(step);
@ -873,7 +875,7 @@ static bool vectorizeForInst(ForInst *loop, int64_t step,
};
auto loadAndStores = matcher::Op(notVectorizedThisPattern);
SmallVector<NestedMatch, 8> loadAndStoresMatches;
loadAndStores.match(loop, &loadAndStoresMatches);
loadAndStores.match(loop->getInstruction(), &loadAndStoresMatches);
for (auto ls : loadAndStoresMatches) {
auto *opInst = cast<OperationInst>(ls.getMatchedInstruction());
auto load = opInst->dyn_cast<LoadOp>();
@ -898,7 +900,7 @@ static bool vectorizeForInst(ForInst *loop, int64_t step,
static FilterFunctionType
isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension) {
return [fastestVaryingMemRefDimension](const Instruction &forInst) {
const auto &loop = cast<ForInst>(forInst);
auto loop = cast<OperationInst>(forInst).cast<AffineForOp>();
return isVectorizableLoopAlongFastestVaryingMemRefDim(
loop, fastestVaryingMemRefDimension);
};
@ -912,7 +914,8 @@ static bool vectorizeNonRoot(ArrayRef<NestedMatch> matches,
/// if all vectorizations in `childrenMatches` have already succeeded
/// recursively in DFS post-order.
static bool doVectorize(NestedMatch oneMatch, VectorizationState *state) {
ForInst *loop = cast<ForInst>(oneMatch.getMatchedInstruction());
auto *loopInst = oneMatch.getMatchedInstruction();
auto loop = cast<OperationInst>(loopInst)->cast<AffineForOp>();
auto childrenMatches = oneMatch.getMatchedChildren();
// 1. DFS postorder recursion, if any of my children fails, I fail too.
@ -924,7 +927,7 @@ static bool doVectorize(NestedMatch oneMatch, VectorizationState *state) {
// 2. This loop may have been omitted from vectorization for various reasons
// (e.g. due to the performance model or pattern depth > vector size).
auto it = state->strategy->loopToVectorDim.find(loop);
auto it = state->strategy->loopToVectorDim.find(loopInst);
if (it == state->strategy->loopToVectorDim.end()) {
return false;
}
@ -939,10 +942,10 @@ static bool doVectorize(NestedMatch oneMatch, VectorizationState *state) {
// exploratory tradeoffs (see top of the file). Apply coarsening, i.e.:
// | ub -> ub
// | step -> step * vectorSize
LLVM_DEBUG(dbgs() << "\n[early-vect] vectorizeForInst by " << vectorSize
LLVM_DEBUG(dbgs() << "\n[early-vect] vectorizeForOp by " << vectorSize
<< " : ");
LLVM_DEBUG(loop->print(dbgs()));
return vectorizeForInst(loop, loop->getStep() * vectorSize, state);
LLVM_DEBUG(loopInst->print(dbgs()));
return vectorizeAffineForOp(loop, loop->getStep() * vectorSize, state);
}
/// Non-root pattern iterates over the matches at this level, calls doVectorize
@ -1186,7 +1189,8 @@ static bool vectorizeOperations(VectorizationState *state) {
/// Each root may succeed independently but will otherwise clean after itself if
/// anything below it fails.
static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) {
auto *loop = cast<ForInst>(m.getMatchedInstruction());
auto loop =
cast<OperationInst>(m.getMatchedInstruction())->cast<AffineForOp>();
VectorizationState state;
state.strategy = strategy;
@ -1197,17 +1201,20 @@ static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) {
// vectorizable. If a pattern is not vectorizable anymore, we just skip it.
// TODO(ntv): implement a non-greedy profitability analysis that keeps only
// non-intersecting patterns.
if (!isVectorizableLoop(*loop)) {
if (!isVectorizableLoop(loop)) {
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ loop is not vectorizable");
return true;
}
FuncBuilder builder(loop); // builder to insert in place of loop
ForInst *clonedLoop = cast<ForInst>(builder.clone(*loop));
auto *loopInst = loop->getInstruction();
FuncBuilder builder(loopInst);
auto clonedLoop =
cast<OperationInst>(builder.clone(*loopInst))->cast<AffineForOp>();
auto fail = doVectorize(m, &state);
/// Sets up error handling for this root loop. This is how the root match
/// maintains a clone for handling failure and restores the proper state via
/// RAII.
ScopeGuard sg2([&fail, loop, clonedLoop]() {
ScopeGuard sg2([&fail, &loop, &clonedLoop]() {
if (fail) {
loop->getInductionVar()->replaceAllUsesWith(
clonedLoop->getInductionVar());
@ -1291,8 +1298,8 @@ PassResult Vectorize::runOnFunction(Function *f) {
if (fail) {
continue;
}
auto *loop = cast<ForInst>(m.getMatchedInstruction());
vectorizeLoopIfProfitable(loop, 0, patternDepth, &strategy);
vectorizeLoopIfProfitable(m.getMatchedInstruction(), 0, patternDepth,
&strategy);
// TODO(ntv): if pattern does not apply, report it; alter the
// cost/benefit.
fail = vectorizeRootMatch(m, &strategy);


@ -204,7 +204,7 @@ func @illegaltype(i0) // expected-error {{invalid integer width}}
// -----
func @malformed_for_percent() {
for i = 1 to 10 { // expected-error {{expected SSA identifier for the loop variable}}
for i = 1 to 10 { // expected-error {{expected SSA operand}}
// -----
@ -222,18 +222,18 @@ func @malformed_for_to() {
func @incomplete_for() {
for %i = 1 to 10 step 2
} // expected-error {{expected '{' before instruction list}}
} // expected-error {{expected '{' to begin block list}}
// -----
func @nonconstant_step(%1 : i32) {
for %2 = 1 to 5 step %1 { // expected-error {{expected integer}}
for %2 = 1 to 5 step %1 { // expected-error {{expected type}}
// -----
func @for_negative_stride() {
for %i = 1 to 10 step -1
} // expected-error {{step has to be a positive integer}}
} // expected-error@-1 {{expected step to be representable as a positive signed integer}}
// -----
@ -510,7 +510,7 @@ func @undefined_function() {
func @bound_symbol_mismatch(%N : index) {
for %i = #map1(%N) to 100 {
// expected-error@-1 {{symbol operand count and affine map symbol count must match}}
// expected-error@-1 {{symbol operand count and integer set symbol count must match}}
}
return
}
@ -521,78 +521,7 @@ func @bound_symbol_mismatch(%N : index) {
func @bound_dim_mismatch(%N : index) {
for %i = #map1(%N, %N)[%N] to 100 {
// expected-error@-1 {{dim operand count and affine map dim count must match}}
}
return
}
// -----
#map1 = (i)[j] -> (i+j)
func @invalid_dim_nested(%N : index) {
for %i = 1 to 100 {
%a = "foo"(%N) : (index)->(index)
for %j = 1 to #map1(%a)[%i] {
// expected-error@-1 {{value '%a' cannot be used as a dimension id}}
}
}
return
}
// -----
#map1 = (i)[j] -> (i+j)
func @invalid_dim_affine_apply(%N : index) {
for %i = 1 to 100 {
%a = "foo"(%N) : (index)->(index)
%w = affine_apply (i)->(i+1) (%a)
for %j = 1 to #map1(%w)[%i] {
// expected-error@-1 {{value '%w' cannot be used as a dimension id}}
}
}
return
}
// -----
#map1 = (i)[j] -> (i+j)
func @invalid_symbol_iv(%N : index) {
for %i = 1 to 100 {
%a = "foo"(%N) : (index)->(index)
for %j = 1 to #map1(%N)[%i] {
// expected-error@-1 {{value '%i' cannot be used as a symbol}}
}
}
return
}
// -----
#map1 = (i)[j] -> (i+j)
func @invalid_symbol_nested(%N : index) {
for %i = 1 to 100 {
%a = "foo"(%N) : (index)->(index)
for %j = 1 to #map1(%N)[%a] {
// expected-error@-1 {{value '%a' cannot be used as a symbol}}
}
}
return
}
// -----
#map1 = (i)[j] -> (i+j)
func @invalid_symbol_affine_apply(%N : index) {
for %i = 1 to 100 {
%w = affine_apply (i)->(i+1) (%i)
for %j = 1 to #map1(%i)[%w] {
// expected-error@-1 {{value '%w' cannot be used as a symbol}}
}
// expected-error@-1 {{dim operand count and integer set dim count must match}}
}
return
}
@ -601,7 +530,7 @@ func @invalid_symbol_affine_apply(%N : index) {
func @large_bound() {
for %i = 1 to 9223372036854775810 {
// expected-error@-1 {{bound or step is too large for index}}
// expected-error@-1 {{integer constant out of range for attribute}}
}
return
}
@ -609,7 +538,7 @@ func @large_bound() {
// -----
func @max_in_upper_bound(%N : index) {
for %i = 1 to max (i)->(N, 100) { //expected-error {{expected SSA operand}}
for %i = 1 to max (i)->(N, 100) { //expected-error {{expected type}}
}
return
}
@ -617,7 +546,7 @@ func @max_in_upper_bound(%N : index) {
// -----
func @step_typo() {
for %i = 1 to 100 step -- 1 { //expected-error {{expected integer}}
for %i = 1 to 100 step -- 1 { //expected-error {{expected constant integer}}
}
return
}


@ -12,9 +12,9 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) {
// CHECK: constant 4 : index loc(callsite("foo" at "mysource.cc":10:8))
%2 = constant 4 : index loc(callsite("foo" at "mysource.cc":10:8))
// CHECK: for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8])
for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) {
}
// CHECK: } loc(fused["foo", "mysource.cc":10:8])
for %i0 = 0 to 8 {
} loc(fused["foo", "mysource.cc":10:8])
// CHECK: } loc(fused<"myPass">["foo", "foo2"])
if #set0(%2) {


@ -230,7 +230,7 @@ func @complex_loops() {
func @triang_loop(%arg0: index, %arg1: memref<?x?xi32>) {
%c = constant 0 : i32 // CHECK: %c0_i32 = constant 0 : i32
for %i0 = 1 to %arg0 { // CHECK: for %i0 = 1 to %arg0 {
for %i1 = %i0 to %arg0 { // CHECK: for %i1 = #map{{[0-9]+}}(%i0) to %arg0 {
for %i1 = (d0)[]->(d0)(%i0)[] to %arg0 { // CHECK: for %i1 = #map{{[0-9]+}}(%i0) to %arg0 {
store %c, %arg1[%i0, %i1] : memref<?x?xi32> // CHECK: store %c0_i32, %arg1[%i0, %i1]
} // CHECK: }
} // CHECK: }
@ -254,7 +254,7 @@ func @loop_bounds(%N : index) {
// CHECK: for %i0 = %0 to %arg0
for %i = %s to %N {
// CHECK: for %i1 = #map{{[0-9]+}}(%i0) to 0
for %j = %i to 0 step 1 {
for %j = (d0)[]->(d0)(%i)[] to 0 step 1 {
// CHECK: %1 = affine_apply #map{{.*}}(%i0, %i1)[%0]
%w1 = affine_apply(d0, d1)[s0] -> (d0+d1) (%i, %j) [%s]
// CHECK: %2 = affine_apply #map{{.*}}(%i0, %i1)[%0]
@ -764,23 +764,3 @@ func @verbose_if(%N: index) {
}
return
}
// CHECK-LABEL: func @verbose_for
func @verbose_for(%arg0 : index, %arg1 : index) {
// CHECK-NEXT: %0 = "for"() {lb: 1, ub: 10} : () -> index {
%a = "for"() {lb: 1, ub: 10 } : () -> index {
// CHECK-NEXT: %1 = "for"() {lb: 1, step: 2, ub: 100} : () -> index {
%b = "for"() {lb: 1, ub: 100, step: 2 } : () -> index {
// CHECK-NEXT: %2 = "for"(%arg0, %arg1) : (index, index) -> index {
%c = "for"(%arg0, %arg1) : (index, index) -> index {
// CHECK-NEXT: %3 = "for"(%arg0) {ub: 100} : (index) -> index {
%d = "for"(%arg0) {ub: 100 } : (index) -> index {
}
}
}
}
return
}


@ -17,9 +17,9 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) {
// CHECK-NEXT: at mysource3.cc:100:10
%3 = constant 4 : index loc(callsite("foo" at callsite("mysource1.cc":10:8 at callsite("mysource2.cc":13:8 at "mysource3.cc":100:10))))
// CHECK: for %i0 = 0 to 8 ["foo", mysource.cc:10:8]
for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) {
}
// CHECK: } ["foo", mysource.cc:10:8]
for %i0 = 0 to 8 {
} loc(fused["foo", "mysource.cc":10:8])
// CHECK: } <"myPass">["foo", "foo2"]
if #set0(%2) {


@ -9,9 +9,9 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) {
// CHECK: "foo"() : () -> i32 loc(unknown)
%1 = "foo"() : () -> i32 loc("foo")
// CHECK: for %i0 = 0 to 8 loc(unknown)
for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) {
}
// CHECK: } loc(unknown)
for %i0 = 0 to 8 {
} loc(fused["foo", "mysource.cc":10:8])
// CHECK: } loc(unknown)
%2 = constant 4 : index


@ -40,6 +40,9 @@
// UNROLL-BY-4: [[MAP7:#map[0-9]+]] = (d0) -> (d0 + 5)
// UNROLL-BY-4: [[MAP8:#map[0-9]+]] = (d0) -> (d0 + 10)
// UNROLL-BY-4: [[MAP9:#map[0-9]+]] = (d0) -> (d0 + 15)
// UNROLL-BY-4: [[MAP10:#map[0-9]+]] = (d0) -> (0)
// UNROLL-BY-4: [[MAP11:#map[0-9]+]] = (d0) -> (d0)
// UNROLL-BY-4: [[MAP12:#map[0-9]+]] = ()[s0] -> (0)
// CHECK-LABEL: func @loop_nest_simplest() {
func @loop_nest_simplest() {
@ -432,7 +435,7 @@ func @loop_nest_single_iteration_after_unroll(%N: index) {
// UNROLL-BY-4-LABEL: func @loop_nest_operand1() {
func @loop_nest_operand1() {
// UNROLL-BY-4: for %i0 = 0 to 100 step 2 {
// UNROLL-BY-4-NEXT: for %i1 = (d0) -> (0)(%i0) to #map{{[0-9]+}}(%i0) step 4
// UNROLL-BY-4-NEXT: for %i1 = [[MAP10]](%i0) to #map{{[0-9]+}}(%i0) step 4
// UNROLL-BY-4-NEXT: %0 = "foo"() : () -> i32
// UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32
// UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32
@ -452,7 +455,7 @@ func @loop_nest_operand1() {
// UNROLL-BY-4-LABEL: func @loop_nest_operand2() {
func @loop_nest_operand2() {
// UNROLL-BY-4: for %i0 = 0 to 100 step 2 {
// UNROLL-BY-4-NEXT: for %i1 = (d0) -> (d0)(%i0) to #map{{[0-9]+}}(%i0) step 4 {
// UNROLL-BY-4-NEXT: for %i1 = [[MAP11]](%i0) to #map{{[0-9]+}}(%i0) step 4 {
// UNROLL-BY-4-NEXT: %0 = "foo"() : () -> i32
// UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32
// UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32
@ -474,7 +477,7 @@ func @loop_nest_operand2() {
func @loop_nest_operand3() {
// UNROLL-BY-4: for %i0 = 0 to 100 step 2 {
for %i = 0 to 100 step 2 {
// UNROLL-BY-4: for %i1 = (d0) -> (d0)(%i0) to #map{{[0-9]+}}(%i0) step 4 {
// UNROLL-BY-4: for %i1 = [[MAP11]](%i0) to #map{{[0-9]+}}(%i0) step 4 {
// UNROLL-BY-4-NEXT: %0 = "foo"() : () -> i32
// UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32
// UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32
@ -492,7 +495,7 @@ func @loop_nest_operand3() {
func @loop_nest_operand4(%N : index) {
// UNROLL-BY-4: for %i0 = 0 to 100 {
for %i = 0 to 100 {
// UNROLL-BY-4: for %i1 = ()[s0] -> (0)()[%arg0] to #map{{[0-9]+}}()[%arg0] step 4 {
// UNROLL-BY-4: for %i1 = [[MAP12]]()[%arg0] to #map{{[0-9]+}}()[%arg0] step 4 {
// UNROLL-BY-4: %0 = "foo"() : () -> i32
// UNROLL-BY-4-NEXT: %1 = "foo"() : () -> i32
// UNROLL-BY-4-NEXT: %2 = "foo"() : () -> i32